Spaces:
Runtime error
Runtime error
| """ | |
| EM Embedded - Simulation Module | |
| Contains simulation logic including run_simulation_only, reset_to_defaults, | |
| and stop handlers. | |
| """ | |
| import re | |
| import asyncio | |
| import threading | |
| import time | |
| import numpy as np | |
| from .state import state, ctrl, _apply_workflow_highlights, is_statevector_estimator_selected, is_ibm_qpu_selected | |
| from .globals import ( | |
| plotter, | |
| simulation_data, | |
| current_mesh, | |
| snapshot_times, | |
| stop_simulation, | |
| qpu_ts_cache, | |
| sim_ts_cache, | |
| set_stop_simulation, | |
| reset_globals, | |
| ) | |
| # Import backend functions | |
| try: | |
| from quantum.utils.delta_impulse_generator import ( | |
| create_impulse_state, create_gaussian_state, | |
| create_impulse_state_from_pos, create_gaussian_state_from_pos, | |
| run_sim, create_time_frames | |
| ) | |
| import quantum.utils.delta_impulse_generator as qutils | |
| except ModuleNotFoundError: | |
| from utils.delta_impulse_generator import ( | |
| create_impulse_state, create_gaussian_state, | |
| create_impulse_state_from_pos, create_gaussian_state_from_pos, | |
| run_sim, create_time_frames | |
| ) | |
| import utils.delta_impulse_generator as qutils | |
# --- Module-level async infrastructure ---
# Heartbeat bookkeeping shared by _start_progress_heartbeat / _stop_progress_heartbeat:
# the worker thread handle and the simulation start timestamp, plus the on/off flag.
_heartbeat_thread = _sim_start_time = None
_heartbeat_on = False
# _simulation_executor: thread-pool slot (_run_simulation_async builds a local
# ThreadPoolExecutor instead). _main_loop: event loop captured by
# _run_simulation_async so worker threads can schedule flushes on it.
_simulation_executor = _main_loop = None
def _get_server():
    """Return the trame server object provided by the sibling ``state`` module."""
    # Imported lazily so this module can be loaded before the server exists.
    from .state import get_server as _server_getter
    return _server_getter()
def _flush_state():
    """Push pending state changes to the browser; intended for the event-loop thread.

    Best-effort: a missing server or any error raised while flushing is ignored.
    """
    try:
        srv = _get_server()
        if not srv:
            return
        srv.state.flush()
    except Exception:
        pass
def _flush_state_threadsafe():
    """Flush state to the browser from any thread (e.g. executor callbacks).

    If the event loop stored in ``_main_loop`` is running, the flush is queued on
    it with ``call_soon_threadsafe``. Otherwise the flush is attempted directly,
    which may not be safe off the loop thread. All errors are ignored.
    """
    try:
        srv = _get_server()
        if not srv:
            return
        loop = _main_loop  # read-only access: no ``global`` declaration needed
        if loop is not None and loop.is_running():
            loop.call_soon_threadsafe(srv.state.flush)
        else:
            srv.state.flush()
    except Exception:
        pass
async def _flush_async():
    """Flush state to the browser, then hand control back to the event loop once."""
    _flush_state()
    # sleep(0) suspends this coroutine for one turn so queued callbacks can run.
    return await asyncio.sleep(0)
def _start_progress_heartbeat():
    """Start a daemon thread that publishes the elapsed simulation time every 100 ms.

    Records the start time in ``_sim_start_time``. While the heartbeat is active
    and ``state.is_running`` is true, the thread writes the elapsed seconds to
    ``state.simulation_elapsed`` and requests a thread-safe flush. Does nothing
    if a heartbeat thread is already alive.
    """
    global _heartbeat_thread, _heartbeat_on, _sim_start_time
    if _heartbeat_thread and _heartbeat_thread.is_alive():
        return
    _sim_start_time = time.time()

    def loop_fn():
        me = threading.current_thread()
        # Keep running only while this thread is still the registered heartbeat.
        # _stop_progress_heartbeat() clears the flag and handle without joining.
        # A quick stop/start would otherwise set _heartbeat_on back to True, and
        # this (old) thread would keep running next to the new one indefinitely.
        while _heartbeat_on and _heartbeat_thread is me:
            if state.is_running and _sim_start_time is not None:
                state.simulation_elapsed = time.time() - _sim_start_time
                _flush_state_threadsafe()  # we are off the event-loop thread
            time.sleep(0.1)  # update every 100 ms

    _heartbeat_on = True
    # Assigned before start() so the new thread sees itself as the registered heartbeat.
    _heartbeat_thread = threading.Thread(target=loop_fn, daemon=True)
    _heartbeat_thread.start()
def _stop_progress_heartbeat():
    """Signal the heartbeat thread to stop and drop its handle (does not join it)."""
    global _heartbeat_on, _heartbeat_thread
    _heartbeat_on, _heartbeat_thread = False, None
def _auto_hide_status_window(delay_seconds=3.0):
    """Hide the status window after ``delay_seconds`` using a background daemon thread.

    Lets a completion message stay visible briefly before the window closes.
    If a simulation is running when the delay expires, the window is left open,
    so a newly started run's progress is not hidden by the previous run's timer.

    Args:
        delay_seconds: Seconds to wait before hiding the window.
    """
    def _hide_after_delay():
        time.sleep(delay_seconds)
        if state.is_running:
            # A newer run now owns the status window; leave it visible.
            return
        state.status_visible = False
        _flush_state_threadsafe()  # called from a worker thread

    hide_thread = threading.Thread(target=_hide_after_delay, daemon=True)
    hide_thread.start()
# Public API of this module (the names exported by ``from ... import *``).
__all__ = [
    # Simulation control
    "run_simulation_only",
    "reset_to_defaults",
    "stop_simulation_handler",
    # Console logging
    "log_to_console",
    "log_message",
    # Plotting and visualisation
    "setup_surface_plot_data",
    "generate_plot",
    "redraw_surface_plot",
    "update_sim_monitor_points",
    "add_dotted_unit_grid",
    "add_dotted_unit_grid_scaled",
    "build_sim_timeseries_plotly",
    "update_value_display",
]
def update_sim_monitor_points():
    """Snap the user's ``state.timeseries_points`` to simulator grid nodes.

    Writes the snapped positions to ``state.timeseries_gridpoints`` and any
    message from the snapper to ``state.timeseries_point_info``. Both are cleared
    when the input is blank. When ``state.nx`` is unset, a prompt to choose a grid
    size is shown instead.
    """
    from .utils import snap_samples_to_grid

    raw = state.timeseries_points
    if not (raw and str(raw).strip()):
        state.timeseries_gridpoints = ""
        state.timeseries_point_info = ""
        return

    grid_n = state.nx
    if grid_n is None:
        state.timeseries_gridpoints = ""
        state.timeseries_point_info = "Select a grid size (nx) to compute the nearest monitor positions."
        return

    gridpoints, note = snap_samples_to_grid(raw, int(grid_n))
    state.timeseries_gridpoints = gridpoints
    state.timeseries_point_info = note or ""
def log_message(message, level="INFO"):
    """Append a ``[HH:MM:SS] [LEVEL] message`` line to ``state.console_logs``."""
    from datetime import datetime

    stamp = datetime.now().strftime("%H:%M:%S")
    entry = f"[{stamp}] [{level}] {message}\n"
    state.console_logs = (state.console_logs or "") + entry
def log_to_console(message):
    """Append ``message`` followed by a newline to ``state.console_output``."""
    existing = state.console_output or ""
    state.console_output = existing + message + "\n"
def setup_surface_plot_data(sim_data, nx):
    """Split flattened snapshots into Ez/Hx/Hy grids and cache them on ``globals``.

    Each snapshot ``u`` in ``sim_data`` is sliced as:
      * Ez: ``u[:nx*nx]`` reshaped to (nx, nx)
      * Hx: ``u[2*nx*nx : 3*nx*nx - nx]`` reshaped to (nx-1, nx)
      * Hy: the last ``nx*nx`` entries with every nx-th one dropped,
        reshaped to (nx, nx-1)

    Also records per-field colour limits (widened by 1e-9 when flat), integer
    meshgrids matching each field's shape, a ``z_scale`` for surface height,
    and stores ``sim_data`` as ``globals.simulation_data``.
    """
    from . import globals as g

    nx = int(nx)
    n2 = nx * nx
    # Drop every nx-th (1-based) entry of the trailing block -> nx*(nx-1) Hy values.
    keep_hy = np.arange(1, n2 + 1) % nx != 0

    field_names = ('Ez', 'Hx', 'Hy')
    g.data_frames = {name: [] for name in field_names}
    g.surface_clims = {name: [np.inf, -np.inf] for name in field_names}

    for u in sim_data:
        frames = {
            'Ez': u[:n2].reshape(nx, nx),
            'Hx': u[2 * n2:3 * n2 - nx].reshape(nx - 1, nx),
            'Hy': u[-n2:][keep_hy].reshape(nx, nx - 1),
        }
        for name, arr in frames.items():
            g.data_frames[name].append(arr)
            if arr.size:
                bounds = g.surface_clims[name]
                bounds[0] = min(bounds[0], arr.min())
                bounds[1] = max(bounds[1], arr.max())

    # Keep every colour range non-degenerate.
    for bounds in g.surface_clims.values():
        if bounds[0] == bounds[1]:
            bounds[0] -= 1e-9
            bounds[1] += 1e-9

    # Integer grid coordinates, shaped to match each field's frames.
    full_axis = np.arange(nx)
    short_axis = np.arange(nx - 1)
    g.X_grids['Ez'], g.Y_grids['Ez'] = np.meshgrid(full_axis, full_axis)
    g.X_grids['Hx'], g.Y_grids['Hx'] = np.meshgrid(full_axis, short_axis)
    g.X_grids['Hy'], g.Y_grids['Hy'] = np.meshgrid(short_axis, full_axis)

    # Scale heights so the largest finite |value| maps to nx/2.
    finite_mags = [
        abs(float(v))
        for bounds in g.surface_clims.values()
        for v in bounds
        if np.isfinite(v)
    ]
    peak = max(finite_mags) if finite_mags else 1e-9
    g.z_scale = (nx / 2) / max(peak, 1e-9)
    g.simulation_data = sim_data
def generate_plot():
    """Render simulation results for the selected ``state.output_type``.

    "Surface Plot" redraws the PyVista surface. Any other value builds a Plotly
    time series for the monitor points in ``state.timeseries_gridpoints``, caches
    it in ``globals.sim_ts_cache`` and pushes it via ``ctrl.sim_ts_update``.
    Time-series errors are reported in ``state.error_message``. Does nothing
    until a simulation has run.
    """
    from . import globals as g

    if not state.simulation_has_run:
        return

    plotter.clear()
    try:
        plotter.disable_picking()
    except Exception:
        pass

    nx = int(state.nx)
    if state.output_type == "Surface Plot":
        redraw_surface_plot()
    else:
        # Time Series -> Plotly for Simulator
        try:
            gridpoint_text = state.timeseries_gridpoints or ""
            monitors = [
                (int(col), int(row))
                for col, row in re.findall(r'\((\d+)\s*,\s*(\d+)\)', gridpoint_text)
            ]
            if not monitors and (state.timeseries_points or "").strip():
                raise ValueError("No valid monitor positions found. Enter (x, y) pairs in [0,1] x [0,1].")
            component = state.timeseries_field
            fig = build_sim_timeseries_plotly(component, monitors, nx, g.snapshot_times, g.simulation_data)
            if fig is not None:
                # Cached for export.
                g.sim_ts_cache["fig"] = fig
                g.sim_ts_cache["field"] = component
                try:
                    ctrl.sim_ts_update(fig)
                except Exception:
                    pass
        except Exception as e:
            state.error_message = f"Plotting Error: {e}"

    ctrl.view_update()
def redraw_surface_plot():
    """Draw the selected field at the snapshot nearest to ``state.time_val``.

    Builds a Delaunay-triangulated PyVista surface from the cached grids on
    ``globals`` (height = value * ``z_scale``) and colours it with the turbo map.
    It then enables point picking via ``update_value_display`` and refreshes the
    view. If no frames or snapshot times are available, the cleared (empty)
    scene is pushed to the client.
    """
    import pyvista as pv
    from . import globals as g

    plotter.clear()
    field = state.surface_field
    frames = g.data_frames.get(field) if g.data_frames is not None else None
    if not frames or g.snapshot_times is None or len(g.snapshot_times) == 0:
        # The scene was just cleared; push that to the client instead of
        # leaving the previously rendered frame on screen.
        ctrl.view_update()
        return

    # Nearest snapshot to the requested time, clamped to the available frames.
    req_t = float(state.time_val)
    times = np.asarray(g.snapshot_times)
    idx = int(np.argmin(np.abs(times - req_t)))
    idx = max(0, min(idx, len(frames) - 1))
    z_data = frames[idx]

    X = g.X_grids[field]
    Y = g.Y_grids[field]
    points = np.c_[X.ravel(), Y.ravel(), z_data.ravel() * g.z_scale]
    mesh = pv.PolyData(points).delaunay_2d()
    mesh['scalars'] = z_data.ravel()
    g.current_mesh = mesh

    # clim is not passed, so the colour range follows this frame's own scalars.
    # Pass clim=g.surface_clims[field] for a fixed range across all frames.
    plotter.add_mesh(
        mesh,
        scalars='scalars',
        cmap="turbo",
        show_scalar_bar=False,
        show_edges=True,
        edge_color='grey',
        line_width=0.5,
    )
    plotter.add_scalar_bar(title=f"{field} Amplitude")

    # Re-arm point picking (disable first so it is not enabled twice).
    try:
        plotter.disable_picking()
    except Exception:
        pass
    plotter.enable_point_picking(callback=update_value_display, show_message=False)

    plotter.add_axes()
    plotter.view_isometric()
    try:
        plotter.camera.parallel_projection = True
    except Exception:
        pass
    ctrl.view_update()
| # --------------------------------------------------------------------------- | |
| # Async Simulation Runner with Full Async Pattern | |
| # --------------------------------------------------------------------------- | |
# Strong references to in-flight simulation tasks. asyncio keeps only weak
# references to tasks, so an unreferenced task can be garbage-collected mid-run.
_background_tasks = set()


def run_simulation_only():
    """Start a simulation run in the background (UI "RUN" entry point).

    Schedules ``_run_simulation_async`` as an asyncio task and returns
    immediately. Logs an error and does nothing if the trame server is not
    available.
    """
    server = _get_server()
    if server is None:
        log_to_console("Error: Server not available")
        return
    task = asyncio.ensure_future(_run_simulation_async())
    # Keep the task alive until it finishes, then drop the reference.
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
| async def _run_simulation_async(): | |
| """ | |
| Async simulation runner that uses thread pool for blocking work. | |
| This allows the UI to update in real-time during simulation. | |
| """ | |
| global _main_loop | |
| from . import globals as g | |
| from .excitation import nearest_node_index | |
| from .qpu import build_qpu_timeseries_plotly_multi | |
| from concurrent.futures import ThreadPoolExecutor | |
| # Capture the main event loop for thread-safe callbacks | |
| _main_loop = asyncio.get_event_loop() | |
| # Create executor for blocking operations | |
| executor = ThreadPoolExecutor(max_workers=1) | |
| loop = _main_loop | |
| # Require selections before running | |
| if not state.geometry_selection: | |
| state.error_message = "Please select a geometry before running the simulation." | |
| log_to_console("Error: Please select a geometry before running.") | |
| state.status_visible = True | |
| state.status_message = "Error: Please select a geometry before running." | |
| state.status_type = "error" | |
| state.show_progress = False | |
| state.is_running = False | |
| state.run_button_text = "RUN!" | |
| await _flush_async() | |
| return | |
| if not state.dist_type: | |
| state.error_message = "Please select an initial state before running the simulation." | |
| log_to_console("Error: Please select an initial state before running.") | |
| state.status_visible = True | |
| state.status_message = "Error: Please select an initial state before running." | |
| state.status_type = "error" | |
| state.show_progress = False | |
| state.is_running = False | |
| state.run_button_text = "RUN!" | |
| await _flush_async() | |
| return | |
| # Show status: Starting simulation | |
| state.status_visible = True | |
| state.status_message = "Initializing simulation..." | |
| log_to_console("Initializing simulation...") | |
| state.status_type = "info" | |
| state.show_progress = True | |
| state.simulation_progress = 0 | |
| await _flush_async() | |
| # Start heartbeat for continuous elapsed time updates | |
| _start_progress_heartbeat() | |
| # Progress callback that updates state (called from worker thread) | |
| # Uses thread-safe flush to push updates to browser | |
| last_logged_percent = [0] | |
| def _progress_callback(percent): | |
| state.simulation_progress = percent | |
| if percent - last_logged_percent[0] >= 10: | |
| log_to_console(f"Simulation progress: {int(percent)}%") | |
| last_logged_percent[0] = percent | |
| _flush_state_threadsafe() # Thread-safe flush! | |
| # Reset stop flag and enable Stop button at start | |
| set_stop_simulation(False) | |
| state.stop_button_disabled = False | |
| plotter.clear() | |
| g.current_mesh = None | |
| state.error_message = "" | |
| state.is_running = True | |
| state.simulation_has_run = False | |
| state.run_button_text = "Running" | |
| # Initial flush to show "Running" state | |
| _flush_state() | |
| nx, T = int(state.nx), float(state.T) | |
| na, R = 1, 4 | |
| try: | |
| state.status_message = "Creating initial state..." | |
| state.simulation_progress = 10 | |
| _flush_state() | |
| if state.dist_type == "Delta": | |
| initial_state = create_impulse_state_from_pos( | |
| (nx, nx), | |
| (float(state.impulse_x), float(state.impulse_y)), | |
| snap_to_grid=True, | |
| ) | |
| else: | |
| initial_state = create_gaussian_state_from_pos( | |
| (nx, nx), | |
| (float(state.mu_x), float(state.mu_y)), | |
| (float(state.sigma_x), float(state.sigma_y)), | |
| snap_to_grid=True, | |
| ) | |
| except ValueError as e: | |
| state.error_message = f"Initial State Error: {e}" | |
| state.status_message = f"Error: {e}" | |
| state.status_type = "error" | |
| state.show_progress = False | |
| state.is_running = False | |
| state.run_button_text = "RUN!" | |
| state.stop_button_disabled = True | |
| _stop_progress_heartbeat() | |
| await _flush_async() | |
| executor.shutdown(wait=False) | |
| return | |
| sve_selected = is_statevector_estimator_selected() | |
| # If Statevector Estimator selected, build time series chart and return | |
| if sve_selected: | |
| try: | |
| log_to_console("Running Statevector Estimator...") | |
| state.status_message = "Step 1: Initializing Statevector Estimator..." | |
| state.simulation_progress = 5 | |
| await _flush_async() | |
| state.qpu_ts_ready = False | |
| state.qpu_plot_style = "display: none; width: 900px; height: 660px; margin: 0 auto;" | |
| state.qpu_ts_other_ready = False | |
| state.qpu_other_plot_style = "display: none; width: 900px; height: 660px; margin: 0 auto;" | |
| # Inputs for QPU | |
| snapshot_dt = float(state.dt_user) | |
| ix_imp, iy_imp = nearest_node_index(float(state.impulse_x), float(state.impulse_y), nx) | |
| impulse_pos = (ix_imp, iy_imp) | |
| # Build configs from primitive slots | |
| configs = [{ | |
| "field": (state.qpu_field_components or "Ez"), | |
| "points": (state.qpu_monitor_gridpoints or ""), | |
| }] | |
| try: | |
| cnt = int(state.qpu_monitor_count or 0) | |
| except Exception: | |
| cnt = 0 | |
| for slot_num in range(2, 2 + cnt): | |
| f = getattr(state, f"qpu_field_components_{slot_num}", "Ez") or "Ez" | |
| p = getattr(state, f"qpu_monitor_gridpoints_{slot_num}", "") or "" | |
| configs.append({"field": f, "points": p}) | |
| state.status_message = "Step 1: Setting up Statevector Estimator..." | |
| state.simulation_progress = 10 | |
| await _flush_async() | |
| # SVE-specific progress callback that maps internal 0-100% to 10-90% range | |
| # and shows appropriate step messages | |
| def _sve_progress_callback(pct): | |
| # Map internal progress (0-100%) to range 10-90% | |
| mapped_pct = 10 + (pct * 0.8) # 10% to 90% | |
| state.simulation_progress = int(mapped_pct) | |
| if mapped_pct < 30: | |
| state.status_message = f"Step 2: Building quantum circuits ({int(mapped_pct)}%)" | |
| elif mapped_pct < 70: | |
| state.status_message = f"Step 3: Running Statevector simulation ({int(mapped_pct)}%)" | |
| else: | |
| state.status_message = f"Step 4: Processing results ({int(mapped_pct)}%)" | |
| _flush_state_threadsafe() | |
| def _sve_series_runner(field_type, positions, total_time, snapshot_dt, nx, impulse_pos, progress_callback=None, print_callback=None): | |
| return qutils.run_sve( | |
| field_type, | |
| positions, | |
| None, | |
| total_time, | |
| snapshot_dt, | |
| nx, | |
| None, | |
| impulse_pos, | |
| progress_callback=progress_callback, | |
| print_callback=print_callback, | |
| ) | |
| # Run SVE in executor to keep UI responsive | |
| def _run_sve_blocking(): | |
| return build_qpu_timeseries_plotly_multi( | |
| configs, nx, T, snapshot_dt, impulse_pos, | |
| series_runner=_sve_series_runner, | |
| progress_callback=_sve_progress_callback, | |
| print_callback=log_to_console | |
| ) | |
| fig = await loop.run_in_executor(executor, _run_sve_blocking) | |
| qpu_ts_cache["fig"] = fig | |
| # Step 5: Creating plots (90-100%) | |
| state.simulation_progress = 95 | |
| state.status_message = "Step 5: Creating plots (95%)" | |
| _flush_state() | |
| try: | |
| ctrl.qpu_ts_update(fig) | |
| except Exception: | |
| pass | |
| state.simulation_has_run = True | |
| state.run_button_text = "Successful!" | |
| state.simulation_progress = 100 | |
| state.status_message = "Statevector Estimator simulation completed successfully!" | |
| log_to_console("Statevector Estimator run completed") | |
| state.status_type = "success" | |
| state.show_progress = False | |
| _auto_hide_status_window(3.0) # Auto-hide after 3 seconds | |
| ready = bool(getattr(fig, "data", None)) and len(fig.data) > 0 | |
| state.qpu_ts_ready = ready | |
| state.qpu_plot_style = ( | |
| "width: 900px; height: 660px; margin: 0 auto;" | |
| if ready else "display: none; width: 900px; height: 660px; margin: 0 auto;" | |
| ) | |
| state.qpu_ts_other_ready = False | |
| state.qpu_other_plot_style = "display: none; width: 900px; height: 660px; margin: 0 auto;" | |
| if not ready: | |
| state.error_message = "No Statevector Estimator time series generated. Check Δt, T, nx, and monitor points." | |
| state.status_message = "Warning: No Statevector Estimator time series generated." | |
| state.status_type = "warning" | |
| log_to_console("Statevector Estimator complete.") | |
| except Exception as e: | |
| state.error_message = f"Statevector Estimator run failed: {e}" | |
| state.status_message = f"Statevector Estimator Error: {e}" | |
| state.status_type = "error" | |
| state.show_progress = False | |
| state.run_button_text = "RUN!" | |
| state.qpu_ts_ready = False | |
| log_to_console(f"Statevector Estimator error: {e}") | |
| finally: | |
| state.is_running = False | |
| state.stop_button_disabled = True | |
| _stop_progress_heartbeat() | |
| await _flush_async() | |
| executor.shutdown(wait=False) | |
| return | |
| # IBM QPU branch | |
| ibm_qpu_selected = is_ibm_qpu_selected() | |
| if ibm_qpu_selected: | |
| try: | |
| log_to_console("Running IBM QPU simulation...") | |
| state.status_message = "Running IBM QPU simulation..." | |
| state.simulation_progress = 5 | |
| await _flush_async() | |
| # Import IBM QPU backend | |
| try: | |
| from quantum.utils.EBU_Quantum.no_body.base_functions import get_field_values as ibm_get_field_values, create_time_frames as ibm_create_time_frames | |
| except ModuleNotFoundError: | |
| from utils.EBU_Quantum.no_body.base_functions import get_field_values as ibm_get_field_values, create_time_frames as ibm_create_time_frames | |
| # Inputs for IBM QPU (single field, single position only!) | |
| snapshot_dt = float(state.dt_user) | |
| ix_imp, iy_imp = nearest_node_index(float(state.impulse_x), float(state.impulse_y), nx) | |
| impulse_pos = (ix_imp, iy_imp) | |
| # Get field and single position from UI | |
| # IBM QPU only supports one field and one position! | |
| field_type = (state.qpu_field_components or "Ez").strip() | |
| if field_type == "All": | |
| field_type = "Ez" # Default to Ez if 'All' selected (not supported by IBM QPU) | |
| log_to_console("Warning: IBM QPU only supports single field. Defaulting to Ez.") | |
| # Parse single monitor position | |
| pts_str = str(state.qpu_monitor_gridpoints or "").strip() | |
| raw_pts = [tuple(map(int, m)) for m in re.findall(r"\((\d+)\s*,\s*(\d+)\)", pts_str)] | |
| if not raw_pts: | |
| # Default to impulse position | |
| monitor_x, monitor_y = impulse_pos | |
| log_to_console(f"No monitor position specified. Using impulse position ({monitor_x}, {monitor_y}).") | |
| else: | |
| # Use only the first position (IBM QPU restriction) | |
| monitor_x, monitor_y = raw_pts[0] | |
| if len(raw_pts) > 1: | |
| log_to_console(f"Warning: IBM QPU only supports single position. Using first: ({monitor_x}, {monitor_y})") | |
| state.status_message = "Step 1: Generating circuit..." | |
| state.simulation_progress = 0 | |
| await _flush_async() | |
| def _ibm_progress_callback(pct, message=None): | |
| """ | |
| Progress callback for IBM QPU with 4-step pattern: | |
| Step 1: Generating circuit (0-10%) | |
| Step 2: Optimising Circuit (10-60%) | |
| Step 3: Job Submitted + Status monitoring (60-90%) | |
| Step 4: Creating Plots (90-100%) | |
| """ | |
| state.simulation_progress = int(pct) | |
| if message: | |
| state.status_message = message | |
| elif pct < 10: | |
| state.status_message = f"Step 1: Generating circuit ({int(pct)}%)" | |
| elif pct < 60: | |
| # Map 10-40% internal to 10-60% display | |
| state.status_message = f"Step 2: Optimising circuit ({int(pct)}%)" | |
| elif pct < 90: | |
| state.status_message = f"Step 3: Job execution ({int(pct)}%)" | |
| else: | |
| state.status_message = f"Step 4: Creating plots ({int(pct)}%)" | |
| _flush_state_threadsafe() # Thread-safe flush from callback thread | |
| # Call the IBM QPU get_field_values function in executor to keep UI responsive | |
| def _run_ibm_qpu(): | |
| return ibm_get_field_values( | |
| field=field_type, | |
| x=monitor_x, | |
| y=monitor_y, | |
| T=float(T), | |
| snapshot_time=snapshot_dt, | |
| nx=nx, | |
| impulse_pos=impulse_pos, | |
| shots=10000, | |
| pm_optimization_level=2, | |
| simulation="False", | |
| optimization="True", | |
| platform="IBM", | |
| progress_callback=_ibm_progress_callback, | |
| print_callback=log_to_console, | |
| ) | |
| field_values = await loop.run_in_executor(executor, _run_ibm_qpu) | |
| # Build time frames to match the output | |
| times = ibm_create_time_frames(float(T), snapshot_dt) | |
| # Build Plotly figure for the single time series | |
| import plotly.graph_objects as go | |
| fig = go.Figure() | |
| # Determine grid dimensions for label | |
| if field_type == 'Ez': | |
| gw, gh = nx, nx | |
| elif field_type == 'Hx': | |
| gw, gh = nx, nx - 1 | |
| else: | |
| gw, gh = nx - 1, nx | |
| from .utils import normalized_position_label | |
| label = normalized_position_label(monitor_x, monitor_y, gw, gh) | |
| # Color based on field type | |
| if field_type == 'Ez': | |
| color = "#d32f2f" # Red | |
| elif field_type == 'Hx': | |
| color = "#388e3c" # Green | |
| else: | |
| color = "#1976d2" # Blue | |
| fig.add_trace( | |
| go.Scatter( | |
| x=list(times), | |
| y=[float(v) for v in field_values], | |
| mode='lines+markers', | |
| name=f"{field_type} @ {label}", | |
| line=dict(color=color, width=2.5), | |
| marker=dict(size=7, symbol="circle", color=color), | |
| hovertemplate=f"{field_type} | t=%{{x:.3f}}s<br>Value=%{{y:.6g}}<extra>{label}</extra>", | |
| ) | |
| ) | |
| max_abs = max((abs(float(v)) for v in field_values), default=1.0) | |
| pad = 0.12 * max_abs if max_abs > 0 else 0.1 | |
| fig.update_layout( | |
| title=f"IBM QPU Time Series - {field_type} @ {label}", | |
| height=660, width=900, | |
| margin=dict(l=50, r=30, t=50, b=50), | |
| hovermode="x unified", | |
| legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1, title_text=""), | |
| paper_bgcolor="#FFFFFF", | |
| plot_bgcolor="#FFFFFF", | |
| ) | |
| fig.update_xaxes(title_text="Time (s)", title_font=dict(size=22), tickfont=dict(size=16), showgrid=True, gridcolor="rgba(0,0,0,.06)") | |
| fig.update_yaxes(title_text="Field Value", title_font=dict(size=22), tickfont=dict(size=16), showgrid=True, gridcolor="rgba(0,0,0,.06)") | |
| fig.update_yaxes(range=[-max_abs - pad, max_abs + pad]) | |
| # Cache the figure for export | |
| qpu_ts_cache["fig"] = fig | |
| qpu_ts_cache["times"] = list(times) | |
| qpu_ts_cache["series_map"] = {(field_type, monitor_x, monitor_y): list(field_values)} | |
| qpu_ts_cache["field"] = field_type | |
| qpu_ts_cache["unique_fields"] = [field_type] | |
| try: | |
| ctrl.qpu_ts_update(fig) | |
| except Exception: | |
| pass | |
| state.simulation_has_run = True | |
| state.run_button_text = "Successful!" | |
| state.simulation_progress = 100 | |
| state.status_message = "IBM QPU simulation completed successfully!" | |
| log_to_console("IBM QPU run completed") | |
| state.status_type = "success" | |
| state.show_progress = False | |
| _auto_hide_status_window(3.0) # Auto-hide after 3 seconds | |
| await _flush_async() # Update UI with completion status | |
| ready = bool(field_values) and len(field_values) > 0 | |
| state.qpu_ts_ready = ready | |
| state.qpu_plot_style = ( | |
| "width: 900px; height: 660px; margin: 0 auto;" | |
| if ready else "display: none; width: 900px; height: 660px; margin: 0 auto;" | |
| ) | |
| state.qpu_ts_other_ready = False | |
| state.qpu_other_plot_style = "display: none; width: 900px; height: 660px; margin: 0 auto;" | |
| # Set filter options for single result | |
| state.qpu_plot_field_options = ["All", field_type] | |
| state.qpu_plot_filter = "All" | |
| state.qpu_plot_position_options = ["All positions", label] | |
| state.qpu_plot_position_filter = "All positions" | |
| if not ready: | |
| state.error_message = "No IBM QPU time series generated. Check Δt, T, nx, and monitor position." | |
| state.status_message = "Warning: No IBM QPU time series generated." | |
| state.status_type = "warning" | |
| log_to_console("IBM QPU complete.") | |
| except Exception as e: | |
| import traceback | |
| state.error_message = f"IBM QPU run failed: {e}" | |
| state.status_message = f"IBM QPU Error: {e}" | |
| state.status_type = "error" | |
| state.show_progress = False | |
| state.run_button_text = "RUN!" | |
| state.qpu_ts_ready = False | |
| log_to_console(f"IBM QPU error: {e}") | |
| log_to_console(traceback.format_exc()) | |
| finally: | |
| state.is_running = False | |
| state.stop_button_disabled = True | |
| _stop_progress_heartbeat() | |
| executor.shutdown(wait=False) | |
| await _flush_async() | |
| return | |
| # IonQ QPU branch | |
| ionq_qpu_selected = state.backend_type == "QPU" and state.selected_qpu == "IonQ QPU" | |
| if ionq_qpu_selected: | |
| try: | |
| log_to_console("Running IonQ QPU simulation...") | |
| state.status_message = "Running IonQ QPU simulation..." | |
| state.simulation_progress = 5 | |
| await _flush_async() | |
| # Import IonQ QPU backend (same module as IBM, different platform param) | |
| try: | |
| from quantum.utils.EBU_Quantum.no_body.base_functions import get_field_values as ionq_get_field_values, create_time_frames as ionq_create_time_frames | |
| except ModuleNotFoundError: | |
| from utils.EBU_Quantum.no_body.base_functions import get_field_values as ionq_get_field_values, create_time_frames as ionq_create_time_frames | |
| # Inputs for IonQ QPU (single field, single position only!) | |
| snapshot_dt = float(state.dt_user) | |
| ix_imp, iy_imp = nearest_node_index(float(state.impulse_x), float(state.impulse_y), nx) | |
| impulse_pos = (ix_imp, iy_imp) | |
| # Get field and single position from UI | |
| field_type = (state.qpu_field_components or "Ez").strip() | |
| if field_type == "All": | |
| field_type = "Ez" | |
| log_to_console("Warning: IonQ QPU only supports single field. Defaulting to Ez.") | |
| # Parse single monitor position | |
| pts_str = str(state.qpu_monitor_gridpoints or "").strip() | |
| raw_pts = [tuple(map(int, m)) for m in re.findall(r"\((\d+)\s*,\s*(\d+)\)", pts_str)] | |
| if not raw_pts: | |
| monitor_x, monitor_y = impulse_pos | |
| log_to_console(f"No monitor position specified. Using impulse position ({monitor_x}, {monitor_y}).") | |
| else: | |
| monitor_x, monitor_y = raw_pts[0] | |
| if len(raw_pts) > 1: | |
| log_to_console(f"Warning: IonQ QPU only supports single position. Using first: ({monitor_x}, {monitor_y})") | |
| state.status_message = "Step 1: Generating circuit..." | |
| state.simulation_progress = 0 | |
| await _flush_async() | |
| def _ionq_progress_callback(pct, message=None): | |
| """Progress callback for IonQ QPU.""" | |
| state.simulation_progress = int(pct) | |
| if message: | |
| state.status_message = message | |
| elif pct < 10: | |
| state.status_message = f"Step 1: Generating circuit ({int(pct)}%)" | |
| elif pct < 60: | |
| state.status_message = f"Step 2: Optimising circuit ({int(pct)}%)" | |
| elif pct < 90: | |
| state.status_message = f"Step 3: Job execution ({int(pct)}%)" | |
| else: | |
| state.status_message = f"Step 4: Creating plots ({int(pct)}%)" | |
| _flush_state_threadsafe() | |
| # Call the IonQ QPU get_field_values function in executor | |
| def _run_ionq_qpu(): | |
| return ionq_get_field_values( | |
| field=field_type, | |
| x=monitor_x, | |
| y=monitor_y, | |
| T=float(T), | |
| snapshot_time=snapshot_dt, | |
| nx=nx, | |
| impulse_pos=impulse_pos, | |
| shots=10000, | |
| pm_optimization_level=1, # IonQ recommended | |
| simulation="False", | |
| optimization="True", | |
| platform="IONQ", # <-- Key difference from IBM | |
| progress_callback=_ionq_progress_callback, | |
| print_callback=log_to_console, | |
| ) | |
| field_values = await loop.run_in_executor(executor, _run_ionq_qpu) | |
| # Build time frames to match the output | |
| times = ionq_create_time_frames(float(T), snapshot_dt) | |
| # Build Plotly figure for the single time series | |
| import plotly.graph_objects as go | |
| fig = go.Figure() | |
| # Determine grid dimensions for label | |
| if field_type == 'Ez': | |
| gw, gh = nx, nx | |
| elif field_type == 'Hx': | |
| gw, gh = nx, nx - 1 | |
| else: | |
| gw, gh = nx - 1, nx | |
| from .utils import normalized_position_label | |
| label = normalized_position_label(monitor_x, monitor_y, gw, gh) | |
| # Color based on field type | |
| if field_type == 'Ez': | |
| color = "#d32f2f" | |
| elif field_type == 'Hx': | |
| color = "#388e3c" | |
| else: | |
| color = "#1976d2" | |
| fig.add_trace( | |
| go.Scatter( | |
| x=list(times), | |
| y=[float(v) for v in field_values], | |
| mode='lines+markers', | |
| name=f"{field_type} @ {label}", | |
| line=dict(color=color, width=2.5), | |
| marker=dict(size=7, symbol="circle", color=color), | |
| hovertemplate=f"{field_type} | t=%{{x:.3f}}s<br>Value=%{{y:.6g}}<extra>{label}</extra>", | |
| ) | |
| ) | |
| max_abs = max((abs(float(v)) for v in field_values), default=1.0) | |
| pad = 0.12 * max_abs if max_abs > 0 else 0.1 | |
| fig.update_layout( | |
| title=f"IonQ QPU Time Series - {field_type} @ {label}", | |
| height=660, width=900, | |
| margin=dict(l=50, r=30, t=50, b=50), | |
| hovermode="x unified", | |
| legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1, title_text=""), | |
| paper_bgcolor="#FFFFFF", | |
| plot_bgcolor="#FFFFFF", | |
| ) | |
| fig.update_xaxes(title_text="Time (s)", title_font=dict(size=22), tickfont=dict(size=16), showgrid=True, gridcolor="rgba(0,0,0,.06)") | |
| fig.update_yaxes(title_text="Field Value", title_font=dict(size=22), tickfont=dict(size=16), showgrid=True, gridcolor="rgba(0,0,0,.06)") | |
| fig.update_yaxes(range=[-max_abs - pad, max_abs + pad]) | |
| # Cache the figure for export | |
| qpu_ts_cache["fig"] = fig | |
| qpu_ts_cache["times"] = list(times) | |
| qpu_ts_cache["series_map"] = {(field_type, monitor_x, monitor_y): list(field_values)} | |
| qpu_ts_cache["field"] = field_type | |
| qpu_ts_cache["unique_fields"] = [field_type] | |
| try: | |
| ctrl.qpu_ts_update(fig) | |
| except Exception: | |
| pass | |
| state.simulation_has_run = True | |
| state.run_button_text = "Successful!" | |
| state.simulation_progress = 100 | |
| state.status_message = "IonQ QPU simulation completed successfully!" | |
| log_to_console("IonQ QPU run completed") | |
| state.status_type = "success" | |
| state.show_progress = False | |
| _auto_hide_status_window(3.0) | |
| await _flush_async() | |
| ready = bool(field_values) and len(field_values) > 0 | |
| state.qpu_ts_ready = ready | |
| state.qpu_plot_style = ( | |
| "width: 900px; height: 660px; margin: 0 auto;" | |
| if ready else "display: none; width: 900px; height: 660px; margin: 0 auto;" | |
| ) | |
| state.qpu_ts_other_ready = False | |
| state.qpu_other_plot_style = "display: none; width: 900px; height: 660px; margin: 0 auto;" | |
| # Set filter options for single result | |
| state.qpu_plot_field_options = ["All", field_type] | |
| state.qpu_plot_filter = "All" | |
| state.qpu_plot_position_options = ["All positions", label] | |
| state.qpu_plot_position_filter = "All positions" | |
| if not ready: | |
| state.error_message = "No IonQ QPU time series generated. Check Δt, T, nx, and monitor position." | |
| state.status_message = "Warning: No IonQ QPU time series generated." | |
| state.status_type = "warning" | |
| log_to_console("IonQ QPU complete.") | |
| except Exception as e: | |
| import traceback | |
| state.error_message = f"IonQ QPU run failed: {e}" | |
| state.status_message = f"IonQ QPU Error: {e}" | |
| state.status_type = "error" | |
| state.show_progress = False | |
| state.run_button_text = "RUN!" | |
| state.qpu_ts_ready = False | |
| log_to_console(f"IonQ QPU error: {e}") | |
| log_to_console(traceback.format_exc()) | |
| finally: | |
| state.is_running = False | |
| state.stop_button_disabled = True | |
| _stop_progress_heartbeat() | |
| executor.shutdown(wait=False) | |
| await _flush_async() | |
| return | |
| # Simulator path - run blocking simulation in executor | |
| log_to_console("Running simulation...") | |
| state.status_message = "Running simulation... This may take a while." | |
| state.simulation_progress = 30 | |
| await _flush_async() | |
| snapshot_dt = float(state.dt_user) | |
| def _stop_check(): | |
| return g.stop_simulation | |
| state.simulation_progress = 50 | |
| await _flush_async() | |
| # Run the blocking simulation in a thread pool to keep UI responsive | |
| def _run_blocking_sim(): | |
| return run_sim( | |
| nx, na, R, initial_state, T, | |
| snapshot_dt=snapshot_dt, | |
| stop_check=_stop_check, | |
| progress_callback=_progress_callback, | |
| print_callback=log_to_console | |
| ) | |
| try: | |
| sim_data, times = await loop.run_in_executor(executor, _run_blocking_sim) | |
| except Exception as e: | |
| state.error_message = f"Simulation error: {e}" | |
| state.status_message = f"Error: {e}" | |
| state.status_type = "error" | |
| state.show_progress = False | |
| state.is_running = False | |
| state.run_button_text = "RUN!" | |
| state.stop_button_disabled = True | |
| _stop_progress_heartbeat() | |
| await _flush_async() | |
| executor.shutdown(wait=False) | |
| return | |
| g.simulation_data = sim_data | |
| g.snapshot_times = times | |
| log_to_console("Simulation complete.") | |
| state.simulation_progress = 80 | |
| state.status_message = "Processing simulation results..." | |
| await _flush_async() | |
| if sim_data.size > 0: | |
| setup_surface_plot_data(sim_data, nx) | |
| state.simulation_has_run = True | |
| state.run_button_text = "Successful!" | |
| state.simulation_progress = 100 | |
| state.status_message = "Simulation completed successfully!" | |
| state.status_type = "success" | |
| state.show_progress = False | |
| _auto_hide_status_window(3.0) # Auto-hide after 3 seconds | |
| generate_plot() | |
| else: | |
| state.error_message = "Simulation produced no data. Check parameters (e.g., T > 0)." | |
| state.status_message = "Error: Simulation produced no data." | |
| state.status_type = "error" | |
| state.show_progress = False | |
| state.run_button_text = "RUN!" | |
| state.is_running = False | |
| state.stop_button_disabled = True | |
| _stop_progress_heartbeat() | |
| await _flush_async() | |
| # Cleanup executor | |
| executor.shutdown(wait=False) | |
def reset_to_defaults():
    """Return the application to its initial, nothing-has-run configuration.

    Sequence:
      1. Raise the stop flag so any in-flight run is asked to stop.
      2. Call ``reset_globals()``.
      3. Overwrite the trame ``state`` and ``qpu_ts_cache`` with default values.
      4. Clear the stop flag again so the next run is allowed to start.
      5. Refresh simulator monitor points, workflow highlights (step 0) and the
         initial-state preview.
    """
    from .excitation import update_initial_state_preview, update_sim_monitor_points

    hidden_plot = "display: none; width: 900px; height: 660px; margin: 0 auto;"

    # Fresh containers are created on every call so state never shares list objects.
    ui_defaults = dict(
        dist_type=None,
        impulse_x=0.5,
        impulse_y=0.5,
        peak_pair="(0.5, 0.5)",
        mu_x=0.5,
        mu_y=0.5,
        sigma_x=0.25,
        sigma_y=0.15,
        mu_pair="(0.5, 0.5)",
        sigma_pair="(0.25, 0.15)",
        nx=None,
        T=10.0,
        time_val=0.0,
        output_type="Surface Plot",
        surface_field="Ez",
        timeseries_field="Ez",
        timeseries_points="(0.5, 0.5)",
        timeseries_gridpoints="",
        timeseries_point_info="",
        error_message="",
        excitation_info_message="",
        excitation_config_open=False,
        is_running=False,
        simulation_has_run=False,
        geometry_selection=None,
        coeff_permittivity=1.0,
        coeff_permeability=1.0,
        run_button_text="RUN!",
        backend_type=None,
        selected_simulator="IBM Qiskit simulator",
        selected_qpu="IBM QPU",
        stop_button_disabled=True,
        export_format="vtk",
        nx_slider_index=None,
        dt_user=0.1,
        temporal_warning="",
        qpu_field_components="Ez",
        qpu_monitor_gridpoints="",
        qpu_monitor_samples="(0.5, 0.5)",
        qpu_monitor_sample_info="",
        qpu_monitor_count=0,
        qpu_plot_filter="All",
        qpu_plot_field_options=["All"],
        qpu_plot_position_filter="All positions",
        qpu_plot_position_options=["All positions"],
        qpu_ts_ready=False,
        qpu_plot_style=hidden_plot,
        qpu_ts_other_ready=False,
        qpu_other_plot_style=hidden_plot,
        pyvista_view_style="aspect-ratio: 1 / 1; width: 100%;",
    )
    cache_defaults = dict(
        times=None,
        series_map=None,
        field=None,
        fig=None,
        positions_by_field={"All": []},
        key_to_label={},
        label_to_keys={},
        nx=None,
    )

    set_stop_simulation(True)   # ask any running simulation to bail out
    reset_globals()
    state.update(ui_defaults)
    qpu_ts_cache.update(cache_defaults)
    set_stop_simulation(False)  # re-arm so the next run is not cancelled immediately

    update_sim_monitor_points()
    _apply_workflow_highlights(0)
    update_initial_state_preview()
    print("Reset to default settings")
def stop_simulation_handler():
    """Request cancellation of the active run and report it in the UI and console.

    Sets the global stop flag (polled by the running simulation), then shows a
    warning-styled status message and echoes the same text to the log console.
    """
    set_stop_simulation(True)
    notice = "Stopping simulation..."
    state.status_message = notice
    state.status_type = "warning"
    log_to_console(notice)
| # --------------------------------------------------------------------------- | |
| # Grid overlay helpers for PyVista plots | |
| # --------------------------------------------------------------------------- | |
def add_dotted_unit_grid(pl, ticks=(0.0, 0.25, 0.5, 0.75, 1.0), segments=48, gap_ratio=0.4, color="#AE8BD8", line_width=0.2):
    """Overlay a dashed 0..1 grid (light Synopsys purple) on plotter ``pl``.

    For each value in ``ticks``, one horizontal line (y = tick) and one vertical
    line (x = tick) are drawn in the z = 0 plane, each split into dashes. A dash
    starts every ``1 / segments`` and covers ``1 - gap_ratio`` of that pitch
    (the fraction is clamped to [0, 1]). The result is added as a non-pickable
    mesh named ``"dotted_unit_grid"``. Any error while building or adding the
    overlay is swallowed, because this is a best-effort decoration.
    """
    import pyvista as pv
    try:
        pitch = 1.0 / float(max(segments, 1))
        dash = pitch * float(max(0.0, min(1.0, 1.0 - gap_ratio)))

        # Dash start offsets along a unit-length line (accumulated, not k * pitch).
        offsets = []
        cursor = 0.0
        while cursor < 1.0 - 1e-9:
            offsets.append(cursor)
            cursor += pitch

        vertices = []
        for tick in ticks:  # horizontal dashes along y = tick
            for start in offsets:
                end = min(start + dash, 1.0) if dash > 0 else start
                vertices.append((start, tick, 0.0))
                vertices.append((end, tick, 0.0))
        for tick in ticks:  # vertical dashes along x = tick
            for start in offsets:
                end = min(start + dash, 1.0) if dash > 0 else start
                vertices.append((tick, start, 0.0))
                vertices.append((tick, end, 0.0))

        # VTK line connectivity: "2, a, b" for every consecutive vertex pair.
        connectivity = []
        for first in range(0, len(vertices), 2):
            connectivity.extend((2, first, first + 1))

        if vertices and connectivity:
            grid = pv.PolyData(np.array(vertices))
            grid.lines = np.array(connectivity)
            pl.add_mesh(grid, color=color, line_width=line_width, name="dotted_unit_grid", pickable=False)
    except Exception:
        pass
def add_dotted_unit_grid_scaled(pl, denom, ticks=(0.0, 0.25, 0.5, 0.75, 1.0), segments=48, gap_ratio=0.6, color="#AE8BD8", line_width=1.0, name="dotted_unit_grid_preview"):
    """Overlay a dashed 0..1 grid scaled to ``[0, denom]`` on the XY plane.

    Tick positions and dash endpoints are multiplied by ``denom``. The grid is
    placed 1e-6 below the lowest z of ``globals.current_mesh`` (or at z = 0 when
    there is no mesh or its z cannot be read) to avoid z-fighting. Any existing
    actor called ``name`` is removed before the new non-pickable mesh is added.
    All errors are swallowed, because this is a best-effort decoration.
    """
    import pyvista as pv
    from . import globals as g
    try:
        pitch = 1.0 / float(max(segments, 1))
        dash = pitch * float(max(0.0, min(1.0, 1.0 - gap_ratio)))

        # Sit just under the mesh so the grid does not flicker through it.
        try:
            z0 = float(g.current_mesh.points[:, 2].min()) - 1e-6 if g.current_mesh is not None else 0.0
        except Exception:
            z0 = 0.0

        # Dash start offsets in unit coordinates (accumulated, not k * pitch).
        offsets = []
        cursor = 0.0
        while cursor < 1.0 - 1e-9:
            offsets.append(cursor)
            cursor += pitch

        vertices = []
        for tick in ticks:  # vertical dashes at x = tick * denom
            x = float(tick) * float(denom)
            for start in offsets:
                vertices.append((x, start * denom, z0))
                vertices.append((x, min(start + dash, 1.0) * denom, z0))
        for tick in ticks:  # horizontal dashes at y = tick * denom
            y = float(tick) * float(denom)
            for start in offsets:
                vertices.append((start * denom, y, z0))
                vertices.append((min(start + dash, 1.0) * denom, y, z0))

        # VTK line connectivity: "2, a, b" for every consecutive vertex pair.
        connectivity = []
        for first in range(0, len(vertices), 2):
            connectivity.extend((2, first, first + 1))

        try:
            pl.remove_actor(name)
        except Exception:
            pass
        if vertices and connectivity:
            grid = pv.PolyData(np.array(vertices))
            grid.lines = np.array(connectivity)
            pl.add_mesh(grid, color=color, line_width=line_width, name=name, pickable=False)
    except Exception:
        pass
| # --------------------------------------------------------------------------- | |
| # Simulator timeseries plot builder | |
| # --------------------------------------------------------------------------- | |
def build_sim_timeseries_plotly(field_type: str, positions, nx: int, times, sim_data):
    """Build a Plotly figure of simulator field values over time at monitor points.

    Args:
        field_type: ``'Ez'``, ``'Hx'`` or ``'Hy'`` for a single field, or ``'All'``
            to add traces for all three. Any other name is handled like ``'Hy'``.
        positions: sequence of ``(px, py)`` grid indices. Points outside a field's
            grid are skipped for that field. Grid sizes (width x height) are
            Ez: nx x nx, Hx: nx x (nx-1), Hy: (nx-1) x nx.
        nx: simulation grid size.
        times: x-axis values. Frame indices ``0..n_frames-1`` are used when ``None``.
        sim_data: array indexed ``sim_data[frame, column]``. Columns are read as follows:
            Ez from column ``py * nx + px``; Hx from columns ``2*nx*nx : 3*nx*nx - nx``;
            Hy from the last ``nx*nx`` columns with every nx-th column dropped.

    Returns:
        A ``plotly.graph_objects.Figure``. With no positions or no frames, a styled
        empty figure is returned. If building fails, the traceback is logged and a
        blank 900x660 figure is returned so the caller can still display something.
    """
    import plotly.graph_objects as go
    from matplotlib import cm as _cm
    from .utils import normalized_position_label

    dash_styles = ["solid", "dash", "dot", "dashdot"]
    marker_symbols = ["circle", "square", "diamond", "triangle-up", "x"]

    try:
        def _rgba_to_hex(rgba):
            r, g, b, _a = rgba
            return "#%02x%02x%02x" % (int(r * 255), int(g * 255), int(b * 255))

        n_frames = int(sim_data.shape[0]) if sim_data is not None else 0
        time_axis = np.asarray(times) if times is not None else np.arange(n_frames)

        def _dims(f):
            """Return (width, height) of field f's grid."""
            if f == 'Ez':
                return nx, nx
            if f == 'Hx':
                return nx, nx - 1
            return nx - 1, nx  # Hy

        def _valid_positions(f, pts):
            """Keep only points inside field f's grid, coerced to int indices."""
            gw, gh = _dims(f)
            return [(int(px), int(py)) for (px, py) in pts if 0 <= px < gw and 0 <= py < gh]

        fig = go.Figure()
        if not positions or sim_data is None or n_frames == 0:
            fig.update_layout(
                title="Time Series (Simulator)",
                height=660, width=900,
                margin=dict(l=50, r=30, t=50, b=50),
                xaxis=dict(title="Time (s)", title_font=dict(size=22), tickfont=dict(size=16), showline=True, linewidth=1, linecolor="rgba(0,0,0,.3)", gridcolor="rgba(0,0,0,.06)", showspikes=True, spikemode='across', spikesnap='cursor'),
                yaxis=dict(title="Field Amplitude", title_font=dict(size=22), tickfont=dict(size=16), showline=True, linewidth=1, linecolor="rgba(0,0,0,.3)", gridcolor="rgba(0,0,0,.06)", zeroline=True, zerolinecolor="rgba(0,0,0,.25)"),
                hovermode="x unified",
                legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
            )
            return fig

        cmap_map = {
            'Ez': _cm.Reds,
            'Hx': _cm.Greens,
            'Hy': _cm.Blues,
        }

        def _add_field_traces(f_name: str, pts):
            """Add one trace per valid point of f_name; return (max |value|, traces added)."""
            gw, gh = _dims(f_name)
            valid_pts = _valid_positions(f_name, pts)
            if not valid_pts:
                return 0.0, 0
            # Slice and reshape each field block once per field, not once per point
            # (and, for Hy, not once per frame).
            if f_name == 'Ez':
                block = None  # Ez is read directly from column py * gw + px
            elif f_name == 'Hx':
                block = sim_data[:, 2 * nx * nx:3 * nx * nx - nx].reshape(n_frames, gh, gw)
            else:  # Hy: last nx*nx columns minus every nx-th one -> nx rows x (nx - 1) cols
                keep = np.arange(1, nx * nx + 1) % nx != 0
                block = sim_data[:, -nx * nx:][:, keep].reshape(n_frames, nx, nx - 1)
            cmap = cmap_map.get(f_name, _cm.Blues)
            num_keys = len(valid_pts)
            max_abs_local = 0.0
            for i, (px, py) in enumerate(valid_pts):
                values = sim_data[:, py * gw + px] if block is None else block[:, py, px]
                try:
                    max_abs_local = max(max_abs_local, float(np.max(np.abs(values))))
                except Exception:
                    pass
                # Spread colormap lightness over 0.3..0.9 so traces of one field differ.
                s_light = 0.3 + 0.6 * (i / (num_keys - 1)) if num_keys > 1 else 0.6
                color_hex = _rgba_to_hex(cmap(s_light))
                label = normalized_position_label(px, py, gw, gh)
                fig.add_trace(go.Scatter(
                    x=time_axis,
                    y=values,
                    mode='lines+markers',
                    name=label,
                    line=dict(color=color_hex, width=2.5, dash=dash_styles[i % len(dash_styles)]),
                    marker=dict(size=7, symbol=marker_symbols[i % len(marker_symbols)], color=color_hex, line=dict(width=0)),
                    hovertemplate=f"{f_name} | t=%{{x:.3f}}s<br>Value=%{{y:.6g}}<extra>{label}</extra>",
                ))
            return max_abs_local, num_keys

        fields = ('Ez', 'Hx', 'Hy') if str(field_type) == 'All' else (str(field_type),)
        max_abs = 0.0
        for f in fields:
            m, _ = _add_field_traces(f, positions)
            max_abs = max(max_abs, m)

        title_suffix = str(field_type) if str(field_type) != 'All' else 'Ez, Hx, Hy'
        fig.update_layout(
            title=f"Time Series (Simulator: {title_suffix})",
            height=660, width=900,
            margin=dict(l=50, r=30, t=50, b=50),
            hovermode="x unified",
            legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1, title_text=""),
            paper_bgcolor="#FFFFFF",
            plot_bgcolor="#FFFFFF",
        )
        fig.update_xaxes(
            title_text="Time (s)", title_font=dict(size=22), tickfont=dict(size=16),
            showgrid=True, gridcolor="rgba(95,37,159,0.08)", zeroline=False,
            showline=True, linewidth=1, linecolor="rgba(0,0,0,.2)",
            showspikes=True, spikemode='across', spikesnap='cursor'
        )
        fig.update_yaxes(
            title_text="Field Amplitude", title_font=dict(size=22), tickfont=dict(size=16),
            showgrid=True, gridcolor="rgba(95,37,159,0.08)", zeroline=True, zerolinecolor="rgba(0,0,0,.25)",
            showline=True, linewidth=1, linecolor="rgba(0,0,0,.2)"
        )
        if max_abs > 0:
            # Symmetric y-range with 12% headroom so the zero line stays centred.
            pad = 0.12 * max_abs
            fig.update_yaxes(range=[-max_abs - pad, max_abs + pad])
        return fig
    except Exception:
        # Best-effort: keep the UI alive with a blank figure, but do not hide the cause.
        import traceback
        log_to_console(f"Simulator time-series plot failed:\n{traceback.format_exc()}")
        return go.Figure(layout=dict(height=660, width=900))
| # --------------------------------------------------------------------------- | |
| # Value display for picked points on the mesh | |
| # --------------------------------------------------------------------------- | |
def update_value_display(point):
    """Show the value under a picked mesh point as a text overlay on the plotter.

    Args:
        point: picked coordinate from the PyVista pick callback. It is passed
            straight to ``find_closest_point`` (presumably an (x, y, z) sequence;
            TODO confirm against the caller).

    The overlay reads the ``'scalars'`` array at the closest mesh point, or uses
    0.0 if the mesh has no such array. Which text is shown depends on state:
      - Before a run, on a unit-square preview mesh: only the raw position.
      - Otherwise: the point is mapped to a grid index using ``state.nx``, and
        after a run ``state.time_val`` is also shown.
      - If ``state.nx`` is unset or not an integer (``reset_to_defaults`` sets it
        to None): the position-only text is shown instead of raising.
    Any previous overlay named ``"value_text"`` is replaced.
    """
    from . import globals as g

    mesh = g.current_mesh
    if mesh is None:
        return
    try:
        plotter.remove_actor("value_text")
    except Exception:
        pass  # no previous overlay to remove
    closest_id = mesh.find_closest_point(point)
    if closest_id == -1:
        return

    value = mesh['scalars'][closest_id] if 'scalars' in mesh.array_names else 0.0
    picked = mesh.points[closest_id]
    px = float(picked[0])
    py = float(picked[1])
    _, xmax, _, ymax, _, _ = mesh.bounds
    # Preview meshes live in [0, 1]^2; simulation meshes use grid-index coordinates.
    is_unit_square = (xmax <= 1.00001 and ymax <= 1.00001)
    position_only = f"Position: ({px:.3f}, {py:.3f})\nValue: {value:.3e}"

    try:
        nx_val = int(state.nx)
    except (TypeError, ValueError):
        nx_val = None

    if (not state.simulation_has_run and is_unit_square) or nx_val is None:
        text = position_only
    else:
        denom = max(float(nx_val - 1), 1.0)
        if is_unit_square:
            ix = int(round(px * denom))
            iy = int(round(py * denom))
            x_code = max(0.0, min(1.0, px))
            y_code = max(0.0, min(1.0, py))
        else:
            ix = int(round(px))
            iy = int(round(py))
            x_code = max(0.0, min(1.0, px / denom))
            y_code = max(0.0, min(1.0, py / denom))
        ix = max(0, min(ix, nx_val - 1))
        iy = max(0, min(iy, nx_val - 1))
        if state.simulation_has_run:
            # Named t_now so the module-level `time` import is not shadowed.
            t_now = float(state.time_val)
            text = f"Index: ({ix}, {iy}) | Position: ({x_code:.3f}, {y_code:.3f})\nTime: {t_now:.2f}s\nValue: {value:.3e}"
        else:
            text = f"Index: ({ix}, {iy}) | Position: ({x_code:.3f}, {y_code:.3f})\nValue: {value:.3e}"
    plotter.add_text(text, name="value_text", position="lower_left", color="black", font_size=10)
    ctrl.view_update()
| # --------------------------------------------------------------------------- | |
| # EM Job Result Upload Processing | |
| # --------------------------------------------------------------------------- | |
def _parse_em_monitor_point(text):
    """Parse an ``"(x, y)"`` grid-index string into ``(int, int)``.

    Missing components default to 0. Any unparsable component makes the whole
    result ``(0, 0)``.
    """
    try:
        parts = [p.strip() for p in text.strip("() ").split(",")]
        x = int(parts[0]) if len(parts) > 0 else 0
        y = int(parts[1]) if len(parts) > 1 else 0
    except (ValueError, IndexError):
        return 0, 0
    return x, y


def _em_job_field_values(res):
    """Convert Estimator job results into one field magnitude per time step.

    Accepts either an iterable of pub results (e.g. ``PrimitiveResult``) or a
    single object exposing ``.data.evs``. For each pub, the first expectation
    value ``<Z>`` is mapped to ``sqrt((1 - <Z>) / 2)``. Pubs without ``evs``
    are skipped.
    """
    def _magnitude(evs):
        z_exp = float(np.array(evs).reshape(-1)[0])
        # Error-mitigated estimates can land slightly above 1; clamp so the
        # magnitude becomes 0 instead of NaN from sqrt of a negative number.
        return float(np.sqrt(max(0.0, (1 - z_exp) / 2)))

    values = []
    if hasattr(res, "__iter__"):
        for pub in res:
            data = getattr(pub, "data", None)
            evs = getattr(data, "evs", None) if data is not None else None
            if evs is not None:
                values.append(_magnitude(evs))
    elif hasattr(res, "data") and hasattr(res.data, "evs"):
        values.append(_magnitude(res.data.evs))
    return values


def process_uploaded_em_job_result():
    """
    Retrieve a finished IBM/IonQ EM job by Job ID and plot its field time series.

    Steps:
    1. Read the Job ID from ``state.em_job_id`` (a trailing ``.json`` is stripped).
    2. Connect to IBM (``API_KEY_IBM_EM``) or IonQ (``API_KEY_IONQ_EM``) according
       to ``state.em_job_platform``, and retrieve the job.
    3. Check that the job is complete, then convert the Estimator expectation
       values (evs) into field magnitudes.
    4. Build time frames from the user-specified T and dt. If their count does not
       match, evenly spaced frames at dt are used instead.
    5. Build a Plotly time-series figure, cache it in ``qpu_ts_cache`` and update
       the QPU plot state.

    Errors are reported via ``state.em_job_upload_error``. On success,
    ``state.em_job_upload_success`` is set. ``state.em_job_is_processing`` is
    always cleared on exit.

    Note:
    - This pathway expects the job was submitted by this EM workflow (Estimator-based).
    - The job is assumed to contain one expectation value per time frame.
    """
    import os
    import plotly.graph_objects as go
    if not state.bound:
        return
    # Validate Job ID
    job_id = None
    if getattr(state, "em_job_id", None) and str(state.em_job_id).strip():
        job_id = str(state.em_job_id).strip()
        if job_id.endswith(".json"):
            job_id = job_id[:-5]
    if not job_id:
        state.em_job_upload_error = "No Job ID provided. Please enter a Job ID."
        return
    # Reset messages
    state.em_job_upload_error = ""
    state.em_job_upload_success = ""
    state.em_job_is_processing = True
    try:
        from .simulation import log_to_console
    except ImportError:
        def log_to_console(msg):
            print(msg)
    log_to_console(f"Processing EM job result for Job ID: {job_id}")
    # Every early `return` below passes through `finally`, which clears
    # em_job_is_processing, so it is not reset at each return site.
    try:
        # --- Parse parameters from the UI -------------------------------------
        field_type = str(state.em_job_field_type or "Ez").strip()
        monitor_x, monitor_y = _parse_em_monitor_point(str(state.em_job_monitor_point or "(0, 0)").strip())
        total_time = float(state.em_job_total_time or 1.0)
        snapshot_dt = float(state.em_job_snapshot_dt or 0.1)
        nx = int(state.em_job_nx or 4)
        platform = str(state.em_job_platform or "IBM")
        log_to_console(f"Parameters: field={field_type}, pos=({monitor_x},{monitor_y}), T={total_time}, dt={snapshot_dt}, nx={nx}, platform={platform}")

        # --- Retrieve the job from the selected provider ----------------------
        if platform.upper() == "IBM":
            try:
                from qiskit_ibm_runtime import QiskitRuntimeService
            except Exception:
                state.em_job_upload_error = "qiskit_ibm_runtime package not available. Please install it."
                return
            ibm_token = os.environ.get("API_KEY_IBM_EM")
            if not ibm_token or not str(ibm_token).strip():
                state.em_job_upload_error = "IBM API token not found. Set API_KEY_IBM_EM environment variable."
                return
            try:
                service = QiskitRuntimeService(
                    channel="ibm_cloud",
                    token=ibm_token,
                    instance="crn:v1:bluemix:public:quantum-computing:us-east:a/15157e4350c04a9dab51b8b8a4a93c86:e29afd91-64bf-4a82-8dbf-731e6c213595::",
                )
            except Exception as e:
                state.em_job_upload_error = f"Failed to connect to IBM Quantum: {e}"
                return
            try:
                job = service.job(job_id)
            except Exception as e:
                state.em_job_upload_error = f"Failed to retrieve IBM job: {e}"
                return
        else:
            # IonQ pathway (Estimator-based in this app)
            try:
                from qiskit_ionq import IonQProvider
            except Exception:
                state.em_job_upload_error = "qiskit_ionq package not available. Please install it."
                return
            ionq_token = os.environ.get("API_KEY_IONQ_EM")
            if not ionq_token or not str(ionq_token).strip():
                state.em_job_upload_error = "IonQ API token not found. Set API_KEY_IONQ_EM environment variable."
                return
            os.environ.setdefault("IONQ_API_TOKEN", ionq_token)
            try:
                provider = IonQProvider()
                job = provider.retrieve_job(job_id)
            except Exception as e:
                state.em_job_upload_error = f"Failed to retrieve IonQ job: {e}"
                return

        # --- Shared: make sure the job finished, then extract field values ----
        try:
            status = job.status()
            status_name = status.name if hasattr(status, "name") else str(status)
            if status_name not in ("DONE", "COMPLETED"):
                state.em_job_upload_error = f"Job is not complete. Current status: {status_name}"
                return
        except Exception:
            pass  # status lookup is best-effort; result() below reports real failures
        try:
            field_values = _em_job_field_values(job.result())
        except Exception as e:
            state.em_job_upload_error = f"Failed to get job results: {e}"
            return
        if not field_values:
            state.em_job_upload_error = "No field values extracted from job. Ensure the job was submitted by the EM Estimator workflow."
            return

        # --- Time axis ---------------------------------------------------------
        try:
            times = create_time_frames(total_time, snapshot_dt)
        except Exception:
            # Fallback: generate linearly
            times = [i * snapshot_dt for i in range(len(field_values))]
        # Ensure times matches field_values length
        if len(times) != len(field_values):
            log_to_console(f"Warning: times ({len(times)}) != field_values ({len(field_values)}), regenerating times")
            times = [i * snapshot_dt for i in range(len(field_values))]
        log_to_console(f"Building time-series plot: {len(field_values)} points")

        # --- Plotly figure -----------------------------------------------------
        fig = go.Figure()
        # Grid dimensions (width, height) per field, used for the position label
        if field_type == 'Ez':
            gw, gh = nx, nx
        elif field_type == 'Hx':
            gw, gh = nx, nx - 1
        else:
            gw, gh = nx - 1, nx
        from .utils import normalized_position_label
        label = normalized_position_label(monitor_x, monitor_y, gw, gh)
        # Color based on field type
        if field_type == 'Ez':
            color = "#d32f2f"  # Red
        elif field_type == 'Hx':
            color = "#388e3c"  # Green
        else:
            color = "#1976d2"  # Blue
        fig.add_trace(
            go.Scatter(
                x=list(times),
                y=[float(v) for v in field_values],
                mode='lines+markers',
                name=f"{field_type} @ {label}",
                line=dict(color=color, width=2.5),
                marker=dict(size=7, symbol="circle", color=color),
                hovertemplate=f"{field_type} | t=%{{x:.3f}}s<br>Value=%{{y:.6g}}<extra>{label}</extra>",
            )
        )
        max_abs = max((abs(float(v)) for v in field_values), default=1.0)
        pad = 0.12 * max_abs if max_abs > 0 else 0.1
        fig.update_layout(
            title=f"{platform} QPU Time Series (Uploaded) - {field_type} @ {label}",
            height=660, width=900,
            margin=dict(l=50, r=30, t=50, b=50),
            hovermode="x unified",
            legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1, title_text=""),
            paper_bgcolor="#FFFFFF",
            plot_bgcolor="#FFFFFF",
        )
        fig.update_xaxes(title_text="Time (s)", title_font=dict(size=22), tickfont=dict(size=16), showgrid=True, gridcolor="rgba(0,0,0,.06)")
        fig.update_yaxes(title_text="Field Value", title_font=dict(size=22), tickfont=dict(size=16), showgrid=True, gridcolor="rgba(0,0,0,.06)")
        fig.update_yaxes(range=[-max_abs - pad, max_abs + pad])

        # --- Cache for export and push to the UI -------------------------------
        qpu_ts_cache["fig"] = fig
        qpu_ts_cache["times"] = list(times)
        qpu_ts_cache["series_map"] = {(field_type, monitor_x, monitor_y): list(field_values)}
        qpu_ts_cache["field"] = field_type
        qpu_ts_cache["unique_fields"] = [field_type]
        try:
            ctrl.qpu_ts_update(fig)
        except Exception:
            pass  # plot widget may not be registered yet
        state.simulation_has_run = True
        state.qpu_ts_ready = True
        state.qpu_plot_style = "width: 900px; height: 660px; margin: 0 auto;"
        # Hide any stale secondary plot, matching the live IonQ QPU run path.
        state.qpu_ts_other_ready = False
        state.qpu_other_plot_style = "display: none; width: 900px; height: 660px; margin: 0 auto;"
        state.qpu_plot_field_options = ["All", field_type]
        state.qpu_plot_filter = "All"
        state.qpu_plot_position_options = ["All positions", label]
        state.qpu_plot_position_filter = "All positions"
        state.em_job_upload_success = f"✓ Successfully processed {len(field_values)} time step(s) from {platform} job {job_id}"
        log_to_console(f"Upload processing complete: {len(field_values)} points plotted")
    except Exception as e:
        state.em_job_upload_error = f"Error processing job result: {e}"
        log_to_console(f"Processing error: {e}")
        import traceback
        log_to_console(traceback.format_exc())
    finally:
        state.em_job_is_processing = False