"""
EM Embedded - QPU Module

Contains QPU timeseries, caching, filters, and handlers.
"""

import re
from collections import defaultdict

import numpy as np
import plotly.graph_objects as go
from matplotlib import cm as _mpl_cm

from .state import state, ctrl
from .globals import qpu_ts_cache
from .utils import normalized_position_label, format_grid_label

# Import backend functions
try:
    import quantum.utils.delta_impulse_generator as qutils
except ModuleNotFoundError:
    import utils.delta_impulse_generator as qutils

__all__ = [
    "cmap_for_field",
    "update_qpu_position_options",
    "filter_series_keys",
    "refresh_qpu_plot_figures",
    "build_qpu_timeseries_plotly_multi",
    "rebuild_qpu_fig_filtered",
    "rebuild_qpu_fig_others",
    "on_qpu_ts_click",
    "on_qpu_ts_clear",
    "qpu_add_monitor_config",
    "qpu_remove_monitor_config",
    "qpu_set_monitor_field",
    "qpu_set_monitor_points",
    "qpu_set_plot_filter",
    "qpu_set_plot_position_filter",
    "qpu_add_monitor_slot",
    "qpu_remove_monitor_slot",
    # Internal functions needed by handlers
    "hide_qpu_plots",
    "update_qpu_sample_slot",
    "refresh_all_qpu_sample_slots",
]

# Inline CSS for the plot containers when shown / hidden.
_PLOT_STYLE_VISIBLE = "width: 900px; height: 660px; margin: 0 auto;"
_PLOT_STYLE_HIDDEN = "display: none; width: 900px; height: 660px; margin: 0 auto;"

# Line-dash and marker cycles shared by every QPU figure builder.
_DASHES = ("solid", "dash", "dot", "dashdot")
_MARKERS = ("circle", "square", "diamond", "triangle-up", "x")

# "(x, y)" non-negative integer pairs inside a monitor's free-text points string.
# Whitespace is tolerated anywhere inside the parentheses: "(3,4)", "( 3 , 4 )".
_POINT_PATTERN = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)")

# Number of sample slots refreshed by refresh_all_qpu_sample_slots (slots 1..5)
# and the cap on qpu_monitor_count enforced by qpu_add_monitor_slot.
_SAMPLE_SLOT_COUNT = 5
_MAX_MONITOR_SLOTS = 4

# Per-slot state variables shifted by qpu_remove_monitor_slot, with the value
# used when the source slot has no such attribute.
_SLOT_STATE_DEFAULTS = (
    ("qpu_field_components", "Ez"),
    ("qpu_monitor_gridpoints", ""),
    ("qpu_monitor_samples", ""),
    ("qpu_monitor_sample_info", ""),
)


# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------

def _grid_dims(field, nx):
    """Return the (width, height) index range for *field* on an ``nx`` grid.

    Ez uses ``nx x nx``, Hx ``nx x (nx - 1)``; any other field (Hy) uses
    ``(nx - 1) x nx``.
    """
    if field == "Ez":
        return nx, nx
    if field == "Hx":
        return nx, nx - 1
    return nx - 1, nx


def _series_color(field, index, count):
    """Return a ``#rrggbb`` colour for the *index*-th of *count* series of *field*.

    Shades are spread over [0.3, 0.9] of the field's colormap; a lone series
    uses 0.6.
    """
    cmap = cmap_for_field(field)
    shade = 0.3 + 0.6 * (index / (count - 1)) if count > 1 else 0.6
    red, green, blue = cmap(shade)[:3]
    return f"#{int(red*255):02x}{int(green*255):02x}{int(blue*255):02x}"


def _series_trace(times, ys, field, label, index, color):
    """Build the Scatter trace used for one monitored (field, x, y) series."""
    return go.Scatter(
        x=times,
        y=ys,
        mode='lines+markers',
        name=label,
        line=dict(color=color, width=2.5, dash=_DASHES[index % len(_DASHES)]),
        marker=dict(size=7, symbol=_MARKERS[index % len(_MARKERS)], color=color),
        # Plotly hover labels break lines on <br>; <extra> fills the secondary box.
        hovertemplate=f"{field} | t=%{{x:.3f}}s<br>Value=%{{y:.6g}}<extra>{label}</extra>",
    )


def _peak_abs(values):
    """Return max(|v|) over *values*, or 0.0 if they cannot be converted to float."""
    try:
        return max(abs(float(v)) for v in values)
    except (TypeError, ValueError, OverflowError):
        return 0.0


def _apply_symmetric_yrange(fig, max_abs):
    """Fix the y-axis to ±(max_abs + 12%) so all traces share a symmetric scale."""
    if max_abs > 0:
        pad = 0.12 * max_abs
        fig.update_yaxes(range=[-max_abs - pad, max_abs + pad])


def _figure_from_series_keys(keys, times, series_map):
    """Plot the cached series for *keys* (ordered by field, x, y) on a new figure.

    Series whose length does not match *times* are skipped.
    Returns ``(figure, peak absolute value over the plotted series)``.
    """
    fig = go.Figure()
    label_map = qpu_ts_cache.get("key_to_label") or {}
    ordered = sorted(keys, key=lambda k: (str(k[0]), k[1], k[2]))
    count = len(ordered)
    peak = 0.0
    for i, (field_name, px, py) in enumerate(ordered):
        ys = series_map.get((field_name, px, py)) or []
        if not ys or len(ys) != len(times):
            continue
        peak = max(peak, _peak_abs(ys))
        label = label_map.get((str(field_name), int(px), int(py))) or format_grid_label(px, py, field_name)
        fig.add_trace(_series_trace(times, ys, field_name, label, i, _series_color(field_name, i, count)))
    return fig, peak


def _current_plot_filters():
    """Return ``(field_filter, position_filter)`` from state, with safe defaults."""
    try:
        field_filter = (state.qpu_plot_filter or "All").strip()
    except Exception:
        field_filter = "All"
    try:
        position_filter = (state.qpu_plot_position_filter or "All positions").strip()
    except Exception:
        position_filter = "All positions"
    return field_filter, position_filter


def _send_figure(updater_name, fig):
    """Push *fig* through ``ctrl.<updater_name>``; UI-side failures are ignored."""
    try:
        getattr(ctrl, updater_name)(fig)
    except Exception:
        # Best-effort: rendering problems must not break the calling handler.
        pass


def _position_entry_label(entry, field_key):
    """Return the dropdown label for one ``positions_by_field`` entry, or None.

    Accepts the dict entries written by build_qpu_timeseries_plotly_multi as well
    as ``(x, y)`` / ``(x, y, field)`` sequences.
    """
    if isinstance(entry, dict):
        label = entry.get("label")
        if not label:
            coords = entry.get("coords") or (None, None)
            label = format_grid_label(coords[0], coords[1], entry.get("field"))
        return label
    if isinstance(entry, (list, tuple)) and len(entry) >= 2:
        if len(entry) >= 3:
            label_field = entry[2]
        else:
            label_field = field_key if field_key not in ("", "All") else None
        return format_grid_label(entry[0], entry[1], label_field)
    return None


def _make_sub_progress(progress_callback, index, total):
    """Map a runner's 0-100 progress onto slice *index* of *total* equal slices."""
    def _report(p):
        if progress_callback:
            base = (index / total) * 100
            fraction = (1 / total) * 100
            progress_callback(base + (p / 100.0) * fraction)
    return _report


def _expand_monitor_configs(configs, nx, impulse_pos):
    """Expand monitor configs into ``(field, [(px, py), ...])`` groups.

    * "All" expands to Ez, Hx and Hy; a missing field defaults to Ez.
    * A config without parsable points falls back to *impulse_pos*.
      If *impulse_pos* is None, that config contributes nothing.
    * Points outside the field's grid are dropped.
    * A (field, x, y) already claimed by an earlier config is dropped, so each
      series is simulated and plotted once.
    * Groups left empty are omitted.
    """
    fallback = [] if impulse_pos is None else [impulse_pos]
    seen = set()
    groups = []
    for cfg in configs or []:
        field_type = (cfg.get("field") or "Ez").strip()
        pts_str = str(cfg.get("points") or "").strip()
        fields = ('Ez', 'Hx', 'Hy') if field_type == 'All' else (field_type,)
        raw_pts = [(int(x), int(y)) for x, y in _POINT_PATTERN.findall(pts_str)] or fallback
        for field in fields:
            gw, gh = _grid_dims(field, nx)
            valid = []
            for px, py in raw_pts:
                if not (0 <= px < gw and 0 <= py < gh):
                    continue
                point = (int(px), int(py))
                if (field, *point) in seen:
                    continue
                seen.add((field, *point))
                valid.append(point)
            if valid:
                groups.append((field, valid))
    return groups


def _show_primary_with_shapes(base_fig, shapes):
    """Set *shapes* on the cached figure and display the figure for the active filters.

    When a field/position filter is active, the filtered figure is shown with the
    same shapes.
    """
    base_fig.update_layout(shapes=shapes)
    qpu_ts_cache["fig"] = base_fig
    field_filter, position_filter = _current_plot_filters()
    shown = rebuild_qpu_fig_filtered(field_filter, position_filter)
    if shown is None:
        # Nothing matches the filters; mirror refresh_qpu_plot_figures' fallback.
        shown = base_fig
    elif shown is not base_fig:
        shown.update_layout(shapes=shapes)
    _send_figure("qpu_ts_update", shown)


# ---------------------------------------------------------------------------
# Plot visibility and sample slots
# ---------------------------------------------------------------------------

def hide_qpu_plots():
    """Hide both QPU time-series plots and reset the position filter to "All positions"."""
    state.qpu_ts_ready = False
    state.qpu_plot_style = _PLOT_STYLE_HIDDEN
    state.qpu_ts_other_ready = False
    state.qpu_other_plot_style = _PLOT_STYLE_HIDDEN
    state.qpu_plot_position_options = ["All positions"]
    state.qpu_plot_position_filter = "All positions"


def update_qpu_sample_slot(slot: int):
    """Snap one slot's sample points to the grid and publish the result to state.

    Slot 1 uses un-suffixed state names (``qpu_monitor_samples``, ...). Slot N > 1
    uses ``_N``-suffixed names. Any exception is swallowed (best-effort UI callback).
    """
    try:
        suffix = "" if slot == 1 else f"_{slot}"
        samples_var = f"qpu_monitor_samples{suffix}"
        info_var = f"qpu_monitor_sample_info{suffix}"
        gp_var = f"qpu_monitor_gridpoints{suffix}"

        samples = getattr(state, samples_var, "")
        if not samples or not str(samples).strip():
            setattr(state, info_var, "")
            setattr(state, gp_var, "")
            return

        nx_val = state.nx
        if nx_val is None:
            setattr(state, info_var, "Select a grid size (nx) to compute the nearest grid position.")
            setattr(state, gp_var, "")
            return

        # Reuse the excitation module's snapping so positions are formatted consistently.
        from .excitation import snap_samples_to_grid

        snapped, message = snap_samples_to_grid(samples, int(nx_val))
        setattr(state, gp_var, snapped)
        setattr(state, info_var, message or "")
    except Exception:
        pass


def refresh_all_qpu_sample_slots():
    """Recompute the grid-snapped info for every QPU sample slot (1..5)."""
    for slot in range(1, _SAMPLE_SLOT_COUNT + 1):
        update_qpu_sample_slot(slot)


def cmap_for_field(field: str):
    """Choose colormap per field (Ez→Reds, Hx→Greens, Hy→Blues)."""
    f = str(field)
    if f == 'Ez':
        return _mpl_cm.Reds
    if f == 'Hx':
        return _mpl_cm.Greens
    return _mpl_cm.Blues


# ---------------------------------------------------------------------------
# Filtering
# ---------------------------------------------------------------------------

def update_qpu_position_options(current_field: str = "All"):
    """Rebuild the position-filter dropdown from the cached positions for *current_field*.

    Falls back to the "All" entries when the field has none, de-duplicates labels
    while preserving order, and resets the current position filter to
    "All positions" if it is no longer offered. Errors are swallowed.
    """
    try:
        field_key = (current_field or "All").strip() or "All"
        pos_map = qpu_ts_cache.get("positions_by_field") or {}
        entries = pos_map.get(field_key) or pos_map.get("All") or []
        labels = []
        for entry in entries:
            label = _position_entry_label(entry, field_key)
            if label:
                labels.append(label)
        options = ["All positions", *dict.fromkeys(labels)]
        state.qpu_plot_position_options = options
        if state.qpu_plot_position_filter not in options:
            state.qpu_plot_position_filter = options[0]
    except Exception:
        pass


def filter_series_keys(series_map, field_filter: str, position_filter: str):
    """Return the ``(field, x, y)`` keys of *series_map* that pass both filters.

    A specific position filter is resolved through ``qpu_ts_cache["label_to_keys"]``;
    an unknown position label yields an empty list.
    """
    keys = list(series_map.keys())
    ff = (field_filter or "All").strip()
    pf = (position_filter or "All positions").strip()
    if ff not in ("", "All"):
        keys = [k for k in keys if str(k[0]) == ff]
    if pf not in ("", "All", "All positions"):
        label_map = qpu_ts_cache.get("label_to_keys") or {}
        label_keys = label_map.get(pf)
        if not label_keys:
            return []
        allowed = {(str(fld), int(px), int(py)) for (fld, px, py) in label_keys}
        keys = [k for k in keys if (str(k[0]), int(k[1]), int(k[2])) in allowed]
    return keys


def refresh_qpu_plot_figures():
    """Re-render the primary and "other components" plots for the current filters.

    Does nothing until a figure with series has been cached. The secondary plot
    is shown only when a single field is selected and no position filter is active.
    """
    field_filter, _ = _current_plot_filters()
    fig_all = qpu_ts_cache.get("fig")
    times = qpu_ts_cache.get("times") or []
    series_map = qpu_ts_cache.get("series_map") or {}
    if fig_all is None or not times or not series_map:
        return

    update_qpu_position_options(field_filter)
    # Re-read: the call above resets a position filter that the field does not offer.
    _, position_filter = _current_plot_filters()

    fig_primary = rebuild_qpu_fig_filtered(field_filter, position_filter)
    if fig_primary is None:
        # No series match the filters: show the full figure instead of an empty one.
        fig_primary = fig_all
    _send_figure("qpu_ts_update", fig_primary)
    state.qpu_ts_ready = True
    state.qpu_plot_style = _PLOT_STYLE_VISIBLE

    show_others = field_filter not in ("", "All") and position_filter in ("", "All", "All positions")
    fig_oth = rebuild_qpu_fig_others(field_filter, position_filter) if show_others else None
    if fig_oth is not None and getattr(fig_oth, "data", None):
        _send_figure("qpu_ts_other_update", fig_oth)
        state.qpu_ts_other_ready = True
        state.qpu_other_plot_style = _PLOT_STYLE_VISIBLE
    else:
        state.qpu_ts_other_ready = False
        state.qpu_other_plot_style = _PLOT_STYLE_HIDDEN


# ---------------------------------------------------------------------------
# Figure builders
# ---------------------------------------------------------------------------

def build_qpu_timeseries_plotly_multi(configs, nx: int, T: float, snapshot_dt: float, impulse_pos, *,
                                      series_runner, progress_callback=None, print_callback=None):
    """Build multi-config Plotly time series using the provided series runner.

    Args:
        configs: Monitor configs, each a mapping with optional ``"field"``
            ("Ez", "Hx", "Hy" or "All"; default "Ez") and ``"points"``
            (free text containing ``(x, y)`` integer pairs).
        nx: Grid size; each field's valid index range is derived from it.
        T: Total time, forwarded to the runner as ``total_time``.
        snapshot_dt: Sampling interval for the time axis, also forwarded to the runner.
        impulse_pos: Fallback ``(x, y)`` for configs whose points string has no pairs.
        series_runner: Called once per (field, positions) group with keyword
            arguments; must return a mapping ``{(px, py): values}``. Series whose
            length differs from the time axis are skipped.
        progress_callback: Optional; receives overall progress. The runner's
            progress is treated as a 0-100 percentage of its own group.
        print_callback: Optional; receives a message when a runner call raises
            (that group is then skipped).

    Returns:
        The combined ``go.Figure``. ``qpu_ts_cache`` and the plot filter state
        are refreshed as a side effect.

    Raises:
        ValueError: If *series_runner* is None.
    """
    if series_runner is None:
        raise ValueError("series_runner callable is required for QPU timeseries builds.")

    times = qutils.create_time_frames(T, snapshot_dt)
    groups = _expand_monitor_configs(configs, nx, impulse_pos)
    total_groups = len(groups)

    fig = go.Figure()
    series_map = {}
    positions_by_field = defaultdict(dict)
    key_to_label = {}
    label_to_keys = defaultdict(set)
    max_abs = 0.0

    for idx, (field_type, valid_positions) in enumerate(groups):
        try:
            field_series = series_runner(
                field_type=field_type,
                positions=valid_positions,
                total_time=float(T),
                snapshot_dt=float(snapshot_dt),
                nx=int(nx),
                impulse_pos=impulse_pos,
                progress_callback=_make_sub_progress(progress_callback, idx, total_groups),
                print_callback=print_callback,
            )
        except Exception as e:
            if print_callback:
                print_callback(f"QPU error for {field_type} positions {valid_positions}: {e}")
            continue

        field_series = field_series or {}
        gw, gh = _grid_dims(field_type, nx)
        num_pts = len(valid_positions)
        for i, (px, py) in enumerate(valid_positions):
            ys = field_series.get((px, py))
            # Length checks (not truthiness) so numpy arrays from the runner work too.
            if ys is None or len(ys) == 0 or len(ys) != len(times):
                continue

            key = (str(field_type), int(px), int(py))
            series_map[key] = list(ys)
            max_abs = max(max_abs, _peak_abs(ys))

            label = normalized_position_label(px, py, gw, gh)
            key_to_label[key] = label
            label_to_keys[label].add(key)
            positions_by_field[str(field_type)][(int(px), int(py))] = {
                "coords": (int(px), int(py)),
                "label": label,
                "field": str(field_type),
            }
            fig.add_trace(_series_trace(times, ys, field_type, label, i,
                                        _series_color(field_type, i, num_pts)))

    unique_fields = sorted({f for (f, _, _) in series_map})
    fig.update_layout(
        title=f"Time Series ({', '.join(unique_fields) if unique_fields else '—'})",
        height=660,
        width=900,
        margin=dict(l=50, r=30, t=50, b=50),
        hovermode="x unified",
        legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1, title_text=""),
        paper_bgcolor="#FFFFFF",
        plot_bgcolor="#FFFFFF",
    )
    fig.update_xaxes(title_text="Time (s)", title_font=dict(size=22), tickfont=dict(size=16),
                     showgrid=True, gridcolor="rgba(0,0,0,.06)")
    fig.update_yaxes(title_text="Field Value", title_font=dict(size=22), tickfont=dict(size=16),
                     showgrid=True, gridcolor="rgba(0,0,0,.06)")
    _apply_symmetric_yrange(fig, max_abs)

    # Update cache
    qpu_ts_cache["times"] = list(times)
    qpu_ts_cache["series_map"] = series_map
    qpu_ts_cache["field"] = ",".join(unique_fields) if len(unique_fields) == 1 else ("multi" if unique_fields else "")
    qpu_ts_cache["fig"] = fig
    qpu_ts_cache["unique_fields"] = list(unique_fields)
    try:
        positions_map_sorted = {}
        all_entries = {}
        for field_name, entry_map in positions_by_field.items():
            entries = [entry_map[xy] for xy in sorted(entry_map, key=lambda xy: (xy[0], xy[1]))]
            positions_map_sorted[field_name] = entries
            for entry in entries:
                # "All" keeps the first entry per label across fields.
                all_entries.setdefault(entry["label"], entry)
        positions_map_sorted["All"] = sorted(all_entries.values(),
                                             key=lambda entry: (entry["coords"][0], entry["coords"][1]))
        qpu_ts_cache["positions_by_field"] = positions_map_sorted
        qpu_ts_cache["key_to_label"] = key_to_label
        qpu_ts_cache["label_to_keys"] = {lbl: sorted(vals) for lbl, vals in label_to_keys.items()}
        qpu_ts_cache["nx"] = int(nx)
    except Exception:
        qpu_ts_cache["positions_by_field"] = {"All": []}
        qpu_ts_cache["key_to_label"] = {}
        qpu_ts_cache["label_to_keys"] = {}

    try:
        state.qpu_plot_field_options = ["All"] + list(unique_fields)
        state.qpu_plot_filter = "All"
        update_qpu_position_options("All")
    except Exception:
        pass
    return fig


def rebuild_qpu_fig_filtered(filter_value: str, position_filter: str = "All positions"):
    """Rebuild QPU figure with field/position filters applied.

    Returns the cached full figure when no filter is active (or nothing is cached),
    None when no series match, and the cached figure if building fails.
    """
    try:
        fv = (filter_value or "All").strip()
        pf = (position_filter or "All positions").strip()
        fig_all = qpu_ts_cache.get("fig")
        times = qpu_ts_cache.get("times") or []
        series_map = qpu_ts_cache.get("series_map") or {}
        if fig_all is None or not times or not series_map:
            return fig_all
        if fv in ("", "All") and pf in ("", "All", "All positions"):
            return fig_all

        keys = filter_series_keys(series_map, fv, pf)
        if not keys:
            return None

        fig, max_abs = _figure_from_series_keys(keys, times, series_map)

        title_parts = []
        if fv not in ("", "All"):
            title_parts.append(fv)
        if pf not in ("", "All", "All positions"):
            title_parts.append(pf)
        suffix = " - ".join(title_parts) if title_parts else "Filtered"
        fig.update_layout(
            title=f"IBM QPU Time Series ({suffix})",
            height=660,
            width=900,
            margin=dict(l=50, r=30, t=50, b=50),
            hovermode="x unified",
            legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
        )
        fig.update_xaxes(title_text="Time (s)", title_font=dict(size=22), tickfont=dict(size=16))
        fig.update_yaxes(title_text="Field Value", title_font=dict(size=22), tickfont=dict(size=16))
        _apply_symmetric_yrange(fig, max_abs)
        return fig
    except Exception:
        return qpu_ts_cache.get("fig")


def rebuild_qpu_fig_others(selected_field: str, position_filter: str = "All positions"):
    """Build Plotly figure for all components except the selected one.

    *position_filter* is accepted for signature symmetry but not applied.
    Returns None when there is nothing to plot or building fails.
    """
    try:
        times = qpu_ts_cache.get("times") or []
        series_map = qpu_ts_cache.get("series_map") or {}
        if not times or not series_map:
            return None
        other_fields = [f for f in sorted({str(k[0]) for k in series_map}) if f != selected_field]
        if not other_fields:
            return None
        keys = [k for k in series_map if str(k[0]) in other_fields]
        if not keys:
            return None

        fig, max_abs = _figure_from_series_keys(keys, times, series_map)
        fig.update_layout(
            title=f"Other Components ({', '.join(other_fields)})",
            height=660,
            width=900,
            margin=dict(l=50, r=30, t=50, b=50),
            hovermode="x unified",
            legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
        )
        _apply_symmetric_yrange(fig, max_abs)
        return fig
    except Exception:
        return None


# ---------------------------------------------------------------------------
# Click handlers
# ---------------------------------------------------------------------------

def on_qpu_ts_click(evt):
    """Mark the sampled time nearest to a plot click with a vertical dotted line.

    Args:
        evt: Plotly click payload; ``evt["points"][0]["x"]`` is the clicked time.

    The line is stored on the cached figure. The figure matching the active
    filters is re-displayed, and ``state.qpu_ts_selected_time`` is set.
    Errors are swallowed.
    """
    try:
        if not evt or "points" not in evt or not evt["points"]:
            return
        x = float(evt["points"][0].get("x"))
        times = qpu_ts_cache.get("times") or []
        fig = qpu_ts_cache.get("fig")
        if not times or fig is None:
            return
        # Snap to the nearest sampled time so the marker sits on a data point.
        idx = int(np.argmin(np.abs(np.asarray(times) - x)))
        sel_t = float(times[idx])
        selection_line = dict(
            type="line", x0=sel_t, x1=sel_t, y0=0, y1=1, xref="x", yref="paper",
            line=dict(color="#5F259F", width=2, dash="dot"),
        )
        _show_primary_with_shapes(fig, [selection_line])
        state.qpu_ts_selected_time = sel_t
    except Exception:
        pass


def on_qpu_ts_clear():
    """Remove the time-selection marker and clear ``state.qpu_ts_selected_time``."""
    try:
        fig = qpu_ts_cache.get("fig")
        if fig is None:
            return
        _show_primary_with_shapes(fig, [])
        state.qpu_ts_selected_time = None
    except Exception:
        pass


# Register click handlers on controller
ctrl.on_qpu_ts_click = on_qpu_ts_click
ctrl.on_qpu_ts_clear = on_qpu_ts_clear


# ---------------------------------------------------------------------------
# Monitor config management
# ---------------------------------------------------------------------------

def _set_monitor_config_value(index, key, value):
    """Set ``configs[index][key] = value`` on a copy and reassign the config list.

    The dict is copied rather than mutated so the previous state value is left
    untouched.
    """
    configs = list(state.qpu_monitor_configs or [])
    if 0 <= index < len(configs):
        updated = dict(configs[index])
        updated[key] = value
        configs[index] = updated
    state.qpu_monitor_configs = configs


def qpu_add_monitor_config():
    """Add a new QPU monitor configuration."""
    from .globals import new_monitor_cfg
    state.qpu_monitor_configs = [*(state.qpu_monitor_configs or []), new_monitor_cfg()]


def qpu_remove_monitor_config(index):
    """Remove a QPU monitor configuration by index (out-of-range is a no-op)."""
    configs = list(state.qpu_monitor_configs or [])
    if 0 <= index < len(configs):
        configs.pop(index)
    state.qpu_monitor_configs = configs


def qpu_set_monitor_field(index, value):
    """Set the field for a monitor config."""
    _set_monitor_config_value(index, "field", value)


def qpu_set_monitor_points(index, value):
    """Set the points for a monitor config."""
    _set_monitor_config_value(index, "points", value)


def qpu_set_plot_filter(value):
    """Set the component filter and refresh chart."""
    state.qpu_plot_filter = value
    update_qpu_position_options(value)
    refresh_qpu_plot_figures()


def qpu_set_plot_position_filter(value):
    """Set the position filter and refresh chart."""
    state.qpu_plot_position_filter = value
    refresh_qpu_plot_figures()


def qpu_add_monitor_slot():
    """Increment ``state.qpu_monitor_count``, capped at 4."""
    try:
        cnt = int(state.qpu_monitor_count or 0)
    except Exception:
        cnt = 0
    if cnt < _MAX_MONITOR_SLOTS:
        state.qpu_monitor_count = cnt + 1


def qpu_remove_monitor_slot(slot_index):
    """Remove monitor slot *slot_index*, shifting later ``_N``-suffixed slots down by one.

    NOTE(review): this shifts ``_{i}``-suffixed names for i in
    [slot_index, count). update_qpu_sample_slot, however, treats slot 1 as
    un-suffixed, and refresh_all_qpu_sample_slots covers slots 1..5. Confirm
    which numbering the UI passes in. The vacated last slot is also not cleared.
    """
    try:
        cnt = int(state.qpu_monitor_count or 0)
    except Exception:
        cnt = 0
    if slot_index <= cnt:
        for i in range(slot_index, cnt):
            for base, default in _SLOT_STATE_DEFAULTS:
                setattr(state, f"{base}_{i}", getattr(state, f"{base}_{i + 1}", default))
        state.qpu_monitor_count = max(0, cnt - 1)


# Register on controller
# NOTE(review): qpu_add_monitor_config is exported but not registered here;
# presumably wired elsewhere — confirm.
ctrl.qpu_remove_monitor_config = qpu_remove_monitor_config
ctrl.qpu_set_monitor_field = qpu_set_monitor_field
ctrl.qpu_set_monitor_points = qpu_set_monitor_points
ctrl.qpu_set_plot_filter = qpu_set_plot_filter
ctrl.qpu_set_plot_position_filter = qpu_set_plot_position_filter
ctrl.qpu_add_monitor_slot = qpu_add_monitor_slot
ctrl.qpu_remove_monitor_slot = qpu_remove_monitor_slot