# quantum/em/qpu.py
# Last commit by harishaseebat92: "updated the EM UI to show the nearest position
# message for IBM QPU/Statevector Estimator Output Preference" (4bfc0a9)
"""
EM Embedded - QPU Module
Contains QPU timeseries, caching, filters, and handlers.
"""
import re
import numpy as np
import plotly.graph_objects as go
from collections import defaultdict
from matplotlib import cm as _mpl_cm
from .state import state, ctrl
from .globals import qpu_ts_cache
from .utils import normalized_position_label, format_grid_label
# Import backend functions
try:
import quantum.utils.delta_impulse_generator as qutils
except ModuleNotFoundError:
import utils.delta_impulse_generator as qutils
# Names exported by ``from qpu import *``. The final group holds helpers that
# handler modules rely on even though they are not UI callbacks themselves.
__all__ = [
    # Figure construction and filtering
    "cmap_for_field",
    "update_qpu_position_options",
    "filter_series_keys",
    "refresh_qpu_plot_figures",
    "build_qpu_timeseries_plotly_multi",
    "rebuild_qpu_fig_filtered",
    "rebuild_qpu_fig_others",
    # Plot click handlers
    "on_qpu_ts_click",
    "on_qpu_ts_clear",
    # Monitor configuration / filter callbacks
    "qpu_add_monitor_config",
    "qpu_remove_monitor_config",
    "qpu_set_monitor_field",
    "qpu_set_monitor_points",
    "qpu_set_plot_filter",
    "qpu_set_plot_position_filter",
    "qpu_add_monitor_slot",
    "qpu_remove_monitor_slot",
    # Internal functions needed by handlers
    "hide_qpu_plots",
    "update_qpu_sample_slot",
    "refresh_all_qpu_sample_slots",
]
def hide_qpu_plots():
    """Hide both QPU time-series panels and reset the position filter.

    Marks the primary and "other components" plots as not ready, gives both
    containers an inline ``display: none`` style, and restores the position
    dropdown to its single default entry.
    """
    collapsed = "display: none; width: 900px; height: 660px; margin: 0 auto;"
    every_position = "All positions"

    state.qpu_ts_ready = False
    state.qpu_plot_style = collapsed
    state.qpu_ts_other_ready = False
    state.qpu_other_plot_style = collapsed
    state.qpu_plot_position_options = [every_position]
    state.qpu_plot_position_filter = every_position
def update_qpu_sample_slot(slot: int):
    """Recompute the nearest-grid information shown for one monitor sample slot.

    Slot 1 uses the unsuffixed state names (``qpu_monitor_samples`` ...); other
    slots use an ``_<slot>`` suffix.
    - Empty samples clear both the info and gridpoint fields.
    - A missing ``state.nx`` shows a prompt to pick a grid size.
    - Otherwise the samples are snapped with
      ``excitation.snap_samples_to_grid(samples, int(nx))``, and the snapped
      text plus its message are stored.

    Any exception is swallowed so a bad entry never breaks the UI callback.
    """
    try:
        tag = "" if slot == 1 else f"_{slot}"
        samples_key, info_key, grid_key = (
            f"qpu_monitor_samples{tag}",
            f"qpu_monitor_sample_info{tag}",
            f"qpu_monitor_gridpoints{tag}",
        )

        raw_samples = getattr(state, samples_key, "")
        if not (raw_samples and str(raw_samples).strip()):
            setattr(state, info_key, "")
            setattr(state, grid_key, "")
            return

        grid_n = state.nx
        if grid_n is None:
            setattr(state, info_key, "Select a grid size (nx) to compute the nearest grid position.")
            setattr(state, grid_key, "")
            return

        # Reuse the excitation module's snapping so position formatting matches.
        from .excitation import snap_samples_to_grid

        snapped, message = snap_samples_to_grid(raw_samples, int(grid_n))
        setattr(state, grid_key, snapped)
        setattr(state, info_key, message or "")
    except Exception:
        pass
def refresh_all_qpu_sample_slots():
    """Recompute nearest-grid info for every QPU sample slot, 1 through 5, in order."""
    for slot_number in (1, 2, 3, 4, 5):
        update_qpu_sample_slot(slot_number)
def cmap_for_field(field: str):
    """Return the matplotlib colormap used to draw a field component.

    ``Ez`` -> Reds and ``Hx`` -> Greens. Any other value, including ``Hy``,
    falls back to Blues. The argument is compared after ``str()`` conversion.
    """
    palette = {"Ez": _mpl_cm.Reds, "Hx": _mpl_cm.Greens}
    return palette.get(str(field), _mpl_cm.Blues)
def update_qpu_position_options(current_field: str = "All"):
    """Refresh the position-filter dropdown for a component selection.

    Steps:
    - Read the entries cached under ``qpu_ts_cache["positions_by_field"]`` for
      *current_field*, falling back to the ``"All"`` bucket when that field
      has none.
    - Convert each entry (a dict with ``label``/``coords``/``field``, or an
      ``(x, y[, field])`` sequence) to a display label.
    - De-duplicate the labels, keeping first-seen order.
    - Publish ``["All positions", *labels]``.

    If the current position filter is no longer offered, it is reset to
    ``"All positions"``. Errors are swallowed so a malformed cache cannot
    break the UI.
    """
    try:
        selected = (current_field or "All").strip() or "All"
        by_field = qpu_ts_cache.get("positions_by_field") or {}
        candidates = by_field.get(selected) or by_field.get("All") or []
        # Field used to tag bare (x, y) sequences; None when no component is selected.
        fallback_field = None if selected in ("", "All") else selected

        def _label_for(item):
            if isinstance(item, dict):
                text = item.get("label")
                if text:
                    return text
                xy = item.get("coords") or (None, None)
                return format_grid_label(xy[0], xy[1], item.get("field"))
            if isinstance(item, (list, tuple)) and len(item) >= 2:
                tag = item[2] if len(item) >= 3 else fallback_field
                return format_grid_label(item[0], item[1], tag)
            return None

        unique_labels = list(dict.fromkeys(text for text in map(_label_for, candidates) if text))
        choices = ["All positions", *unique_labels]
        state.qpu_plot_position_options = choices
        if state.qpu_plot_position_filter not in choices:
            state.qpu_plot_position_filter = choices[0]
    except Exception:
        pass
def filter_series_keys(series_map, field_filter: str, position_filter: str):
    """Return the ``(field, x, y)`` keys of *series_map* that pass both filters.

    Args:
        series_map: Mapping keyed by ``(field, x, y)`` tuples.
        field_filter: Component name, or ``"All"``/empty/None for every component.
        position_filter: Position label, or ``"All positions"``/``"All"``/empty/None
            for every position. A label is resolved through
            ``qpu_ts_cache["label_to_keys"]``; an unknown label yields ``[]``.

    Returns:
        A list of matching keys in the mapping's iteration order.
    """
    candidates = list(series_map.keys())
    wanted_field = (field_filter or "All").strip()
    wanted_position = (position_filter or "All positions").strip()

    if wanted_field not in ("", "All"):
        candidates = [key for key in candidates if str(key[0]) == wanted_field]

    # No position restriction: the field-filtered list is the answer.
    if wanted_position in ("", "All", "All positions"):
        return candidates

    matching = (qpu_ts_cache.get("label_to_keys") or {}).get(wanted_position)
    if not matching:
        return []
    permitted = {(str(f), int(x), int(y)) for (f, x, y) in matching}
    return [key for key in candidates if (str(key[0]), int(key[1]), int(key[2])) in permitted]
def refresh_qpu_plot_figures():
    """Redraw the QPU time-series panels from ``qpu_ts_cache`` using the current filters.

    Filters:
    - Reads the component filter from ``state.qpu_plot_filter``.
    - Refreshes the position dropdown for that component.
    - Only then reads ``state.qpu_plot_position_filter``. The refresh can reset
      that value, so it must be read afterwards.

    Panels:
    - The primary plot shows the filtered figure. It falls back to the cached
      unfiltered figure when nothing matches.
    - The "other components" plot is shown only when a specific component is
      selected, no position filter is active, and the figure has traces.

    Does nothing when the cache lacks a figure, times or series.
    """
    shown_style = "width: 900px; height: 660px; margin: 0 auto;"
    hidden_style = "display: none; width: 900px; height: 660px; margin: 0 auto;"

    try:
        field_filter = (state.qpu_plot_filter or "All").strip()
    except Exception:
        field_filter = "All"

    fig_all = qpu_ts_cache.get("fig")
    times = qpu_ts_cache.get("times") or []
    series_map = qpu_ts_cache.get("series_map") or {}
    if fig_all is None or not times or not series_map:
        return

    # This may reset state.qpu_plot_position_filter when the current position
    # is not available for the selected component, so the position filter is
    # read afterwards to avoid acting on a stale value.
    update_qpu_position_options(field_filter)
    try:
        position_filter = (state.qpu_plot_position_filter or "All positions").strip()
    except Exception:
        position_filter = "All positions"

    fig_primary = rebuild_qpu_fig_filtered(field_filter, position_filter)
    if fig_primary is None:
        # Nothing matched the filters: fall back to the unfiltered figure.
        fig_primary = fig_all
    try:
        ctrl.qpu_ts_update(fig_primary)
    except Exception:
        pass
    state.qpu_ts_ready = True
    state.qpu_plot_style = shown_style

    show_others = field_filter not in ("", "All") and position_filter in ("", "All", "All positions")
    fig_oth = rebuild_qpu_fig_others(field_filter, position_filter) if show_others else None
    if fig_oth is not None and getattr(fig_oth, "data", None):
        try:
            ctrl.qpu_ts_other_update(fig_oth)
        except Exception:
            pass
        state.qpu_ts_other_ready = True
        state.qpu_other_plot_style = shown_style
    else:
        state.qpu_ts_other_ready = False
        state.qpu_other_plot_style = hidden_style
def _qpu_field_grid_dims(field_type, nx):
    """Return ``(width, height)`` of the sample grid for a field component.

    ``Ez`` uses ``nx x nx``, ``Hx`` uses ``nx x (nx - 1)``, and any other
    component (``Hy``) uses ``(nx - 1) x nx``. A monitor point is valid when
    ``0 <= x < width`` and ``0 <= y < height``.
    """
    if field_type == 'Ez':
        return nx, nx
    if field_type == 'Hx':
        return nx, nx - 1
    return nx - 1, nx


def _qpu_trace_color(cmap, index, count):
    """Return a ``#rrggbb`` colour sampled from *cmap* for trace *index* of *count*.

    Samples are spread evenly over [0.3, 0.9]; a single trace uses 0.6.
    """
    shade = 0.3 + 0.6 * (index / (count - 1)) if count > 1 else 0.6
    rgba = cmap(shade)
    return f"#{int(rgba[0]*255):02x}{int(rgba[1]*255):02x}{int(rgba[2]*255):02x}"


def build_qpu_timeseries_plotly_multi(configs, nx: int, T: float, snapshot_dt: float, impulse_pos, *, series_runner, progress_callback=None, print_callback=None):
    """Run the QPU series runner for each monitor config and build one combined figure.

    Args:
        configs: Iterable of dicts with ``"field"`` (``"Ez"``, ``"Hx"``, ``"Hy"``
            or ``"All"``; default ``"Ez"``) and ``"points"`` (text containing
            ``(x, y)`` pairs). When no pair is found, ``impulse_pos`` is used.
        nx: Grid size; points outside a component's grid are dropped.
        T: Total time, passed to ``qutils.create_time_frames`` and the runner.
        snapshot_dt: Sampling interval, passed to the same two places.
        impulse_pos: Fallback monitor point.
        series_runner: Callable invoked once per (field, positions) group. It
            must return a mapping ``(x, y) -> values``; series whose length
            differs from the time axis are skipped.
        progress_callback: Optional callable receiving overall progress (0-100).
        print_callback: Optional callable receiving runner error messages.

    Returns:
        The Plotly figure. As side effects, it refreshes ``qpu_ts_cache`` and
        resets ``state.qpu_plot_field_options``/``qpu_plot_filter`` and the
        position options.

    Raises:
        ValueError: If ``series_runner`` is None.
    """
    if series_runner is None:
        raise ValueError("series_runner callable is required for QPU timeseries builds.")
    times = qutils.create_time_frames(T, snapshot_dt)
    fig = go.Figure()

    # Expand every config into (field, in-bounds points) groups; "All" fans out
    # to each component.
    cfg_expanded = []
    for cfg in (configs or []):
        field_type = (cfg.get("field") or "Ez").strip()
        pts_str = str(cfg.get("points") or "").strip()
        fields = ('Ez', 'Hx', 'Hy') if field_type == 'All' else (field_type,)
        # Drop repeated pairs so a point typed twice is not run and plotted twice.
        parsed = list(dict.fromkeys(
            tuple(map(int, m)) for m in re.findall(r"\((\d+)\s*,\s*(\d+)\)", pts_str)
        ))
        raw_pts = parsed or [impulse_pos]
        for f in fields:
            gw, gh = _qpu_field_grid_dims(f, nx)
            valid = [(int(px), int(py)) for (px, py) in raw_pts if 0 <= px < gw and 0 <= py < gh]
            if valid:
                cfg_expanded.append((f, valid))

    series_map = {}
    positions_by_field = defaultdict(dict)
    key_to_label = {}
    label_to_keys = defaultdict(set)
    max_abs = 0.0
    dashes = ["solid", "dash", "dot", "dashdot"]
    markers = ["circle", "square", "diamond", "triangle-up", "x"]
    total_configs = len(cfg_expanded)

    def _make_sub_progress(idx):
        # Bind idx per group (rather than closing over the loop variable) and
        # map the runner's 0-100 progress into this group's share of the total.
        def _sub_progress(p):
            if progress_callback:
                base = (idx / total_configs) * 100
                fraction = (1 / total_configs) * 100
                progress_callback(base + (p / 100.0) * fraction)
        return _sub_progress

    for idx, (field_type, valid_positions) in enumerate(cfg_expanded):
        try:
            series_map_field = series_runner(
                field_type=field_type,
                positions=valid_positions,
                total_time=float(T),
                snapshot_dt=float(snapshot_dt),
                nx=int(nx),
                impulse_pos=impulse_pos,
                progress_callback=_make_sub_progress(idx),
                print_callback=print_callback,
            )
        except Exception as e:
            # A failing group is reported and skipped; the others still plot.
            msg = f"QPU error for {field_type} positions {valid_positions}: {e}"
            if print_callback:
                print_callback(msg)
            continue

        cmap = cmap_for_field(field_type)
        gw, gh = _qpu_field_grid_dims(field_type, nx)
        results = series_map_field or {}
        num_pts = len(valid_positions)
        for i, (px, py) in enumerate(valid_positions):
            ys = results.get((px, py))
            # Explicit length checks also work for array-like series, where
            # truthiness would be ambiguous.
            if ys is None or len(ys) == 0 or len(ys) != len(times):
                continue
            series_map[(field_type, px, py)] = list(ys)
            try:
                max_abs = max(max_abs, max(abs(float(v)) for v in ys))
            except Exception:
                pass
            color_hex = _qpu_trace_color(cmap, i, num_pts)
            label = normalized_position_label(px, py, gw, gh)
            key = (str(field_type), int(px), int(py))
            key_to_label[key] = label
            label_to_keys[label].add(key)
            positions_by_field[str(field_type)][(int(px), int(py))] = {
                "coords": (int(px), int(py)),
                "label": label,
                "field": str(field_type),
            }
            fig.add_trace(
                go.Scatter(
                    x=times,
                    y=ys,
                    mode='lines+markers',
                    name=label,
                    line=dict(color=color_hex, width=2.5, dash=dashes[i % len(dashes)]),
                    marker=dict(size=7, symbol=markers[i % len(markers)], color=color_hex),
                    hovertemplate=f"{field_type} | t=%{{x:.3f}}s<br>Value=%{{y:.6g}}<extra>{label}</extra>",
                )
            )

    unique_fields = sorted({f for (f, _, _) in series_map.keys()})
    fig.update_layout(
        title=f"Time Series ({', '.join(unique_fields) if unique_fields else '—'})",
        height=660, width=900,
        margin=dict(l=50, r=30, t=50, b=50),
        hovermode="x unified",
        legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1, title_text=""),
        paper_bgcolor="#FFFFFF",
        plot_bgcolor="#FFFFFF",
    )
    fig.update_xaxes(title_text="Time (s)", title_font=dict(size=22), tickfont=dict(size=16), showgrid=True, gridcolor="rgba(0,0,0,.06)")
    fig.update_yaxes(title_text="Field Value", title_font=dict(size=22), tickfont=dict(size=16), showgrid=True, gridcolor="rgba(0,0,0,.06)")
    if max_abs > 0:
        # Symmetric y-range with 12% headroom around the largest magnitude.
        pad = 0.12 * max_abs
        fig.update_yaxes(range=[-max_abs - pad, max_abs + pad])

    # Update cache
    qpu_ts_cache["times"] = list(times)
    qpu_ts_cache["series_map"] = series_map
    qpu_ts_cache["field"] = ",".join(unique_fields) if len(unique_fields) == 1 else ("multi" if unique_fields else "")
    qpu_ts_cache["fig"] = fig
    qpu_ts_cache["unique_fields"] = list(unique_fields)
    try:
        positions_map_sorted = {}
        all_entries = {}
        for field_name, entry_map in positions_by_field.items():
            entries = [entry_map[key] for key in sorted(entry_map.keys(), key=lambda xy: (xy[0], xy[1]))]
            positions_map_sorted[field_name] = entries
            for entry in entries:
                all_entries.setdefault(entry["label"], entry)
        positions_map_sorted["All"] = sorted(all_entries.values(), key=lambda entry: (entry["coords"][0], entry["coords"][1]))
        qpu_ts_cache["positions_by_field"] = positions_map_sorted
        qpu_ts_cache["key_to_label"] = key_to_label
        qpu_ts_cache["label_to_keys"] = {lbl: sorted(list(vals)) for lbl, vals in label_to_keys.items()}
        qpu_ts_cache["nx"] = int(nx)
    except Exception:
        qpu_ts_cache["positions_by_field"] = {"All": []}
        qpu_ts_cache["key_to_label"] = {}
        qpu_ts_cache["label_to_keys"] = {}

    try:
        state.qpu_plot_field_options = ["All"] + list(unique_fields)
        state.qpu_plot_filter = "All"
        update_qpu_position_options("All")
    except Exception:
        pass
    return fig
def rebuild_qpu_fig_filtered(filter_value: str, position_filter: str = "All positions"):
    """Build a QPU time-series figure restricted to a component and/or position.

    Args:
        filter_value: Component name (e.g. ``"Ez"``), or ``"All"``/empty for
            no component filter.
        position_filter: Position label, or ``"All positions"``/``"All"``/empty
            for no position filter.

    Returns:
        - The cached ``qpu_ts_cache["fig"]`` when no filter is active or the
          cache has no data.
        - A new figure holding the matching series otherwise.
        - ``None`` when the filters match nothing.
        - On unexpected errors, the cached figure.
    """
    try:
        fv = (filter_value or "All").strip()
        pf = (position_filter or "All positions").strip()
        fig_all = qpu_ts_cache.get("fig")
        times = qpu_ts_cache.get("times") or []
        series_map = qpu_ts_cache.get("series_map") or {}
        if fig_all is None or not times or not series_map:
            return fig_all
        if fv in ("", "All") and pf in ("", "All", "All positions"):
            return fig_all
        keys = filter_series_keys(series_map, fv, pf)
        if not keys:
            return None

        fig = go.Figure()
        dashes = ["solid", "dash", "dot", "dashdot"]
        markers = ["circle", "square", "diamond", "triangle-up", "x"]
        max_abs = 0.0
        label_map = qpu_ts_cache.get("key_to_label") or {}
        sorted_keys = sorted(keys, key=lambda x: (str(x[0]), x[1], x[2]))
        num_keys = len(sorted_keys)
        for i, k in enumerate(sorted_keys):
            field_name, px, py = k
            ys = series_map.get(k) or []
            if not ys or len(ys) != len(times):
                continue
            try:
                max_abs = max(max_abs, max(abs(float(v)) for v in ys))
            except Exception:
                pass
            cmap = cmap_for_field(field_name)
            s_light = 0.3 + 0.6 * (i / (num_keys - 1)) if num_keys > 1 else 0.6
            rgba = cmap(s_light)
            color_hex = f"#{int(rgba[0]*255):02x}{int(rgba[1]*255):02x}{int(rgba[2]*255):02x}"
            label = label_map.get((str(field_name), int(px), int(py))) or format_grid_label(px, py, field_name)
            fig.add_trace(go.Scatter(
                x=times,
                y=ys,
                mode='lines+markers',
                name=label,
                line=dict(color=color_hex, width=2.5, dash=dashes[i % len(dashes)]),
                marker=dict(size=7, symbol=markers[i % len(markers)], color=color_hex),
                hovertemplate=f"{field_name} | t=%{{x:.3f}}s<br>Value=%{{y:.6g}}<extra>{label}</extra>",
            ))

        title_parts = []
        if fv not in ("", "All"):
            title_parts.append(fv)
        if pf not in ("", "All", "All positions"):
            title_parts.append(pf)
        suffix = " - ".join(title_parts) if title_parts else "Filtered"
        # Use the same background, grid and legend styling as the unfiltered
        # figure, so applying a filter does not change the plot's look.
        fig.update_layout(
            title=f"IBM QPU Time Series ({suffix})",
            height=660, width=900,
            margin=dict(l=50, r=30, t=50, b=50),
            hovermode="x unified",
            legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1, title_text=""),
            paper_bgcolor="#FFFFFF",
            plot_bgcolor="#FFFFFF",
        )
        fig.update_xaxes(title_text="Time (s)", title_font=dict(size=22), tickfont=dict(size=16), showgrid=True, gridcolor="rgba(0,0,0,.06)")
        fig.update_yaxes(title_text="Field Value", title_font=dict(size=22), tickfont=dict(size=16), showgrid=True, gridcolor="rgba(0,0,0,.06)")
        if max_abs > 0:
            pad = 0.12 * max_abs
            fig.update_yaxes(range=[-max_abs - pad, max_abs + pad])
        return fig
    except Exception:
        return qpu_ts_cache.get("fig")
def rebuild_qpu_fig_others(selected_field: str, position_filter: str = "All positions"):
    """Build a figure containing every cached component except *selected_field*.

    Args:
        selected_field: Component shown in the primary plot; its series are excluded.
        position_filter: Accepted for signature parity with
            ``rebuild_qpu_fig_filtered``; currently unused.

    Returns:
        A Plotly figure, or ``None`` when there is no cached data, no other
        component, or an unexpected error.
    """
    try:
        times = qpu_ts_cache.get("times") or []
        series_map = qpu_ts_cache.get("series_map") or {}
        if not times or not series_map:
            return None
        selected = str(selected_field or "").strip()
        all_fields = sorted({str(k[0]) for k in series_map})
        other_fields = [f for f in all_fields if f != selected]
        if not other_fields:
            return None
        other_set = set(other_fields)
        keys = [k for k in series_map if str(k[0]) in other_set]
        if not keys:
            return None

        fig = go.Figure()
        dashes = ["solid", "dash", "dot", "dashdot"]
        markers = ["circle", "square", "diamond", "triangle-up", "x"]
        max_abs = 0.0
        label_map = qpu_ts_cache.get("key_to_label") or {}
        sorted_keys = sorted(keys, key=lambda x: (str(x[0]), x[1], x[2]))
        num_keys = len(sorted_keys)
        for i, k in enumerate(sorted_keys):
            field_name, px, py = k
            ys = series_map.get(k) or []
            if not ys or len(ys) != len(times):
                continue
            try:
                max_abs = max(max_abs, max(abs(float(v)) for v in ys))
            except Exception:
                pass
            cmap = cmap_for_field(field_name)
            s_light = 0.6 if num_keys == 1 else 0.3 + 0.6 * (i / (num_keys - 1))
            rgba = cmap(s_light)
            color_hex = f"#{int(rgba[0]*255):02x}{int(rgba[1]*255):02x}{int(rgba[2]*255):02x}"
            label = label_map.get((str(field_name), int(px), int(py))) or format_grid_label(px, py, field_name)
            fig.add_trace(go.Scatter(
                x=times,
                y=ys,
                mode='lines+markers',
                name=label,
                line=dict(color=color_hex, width=2.5, dash=dashes[i % len(dashes)]),
                marker=dict(size=7, symbol=markers[i % len(markers)], color=color_hex),
                hovertemplate=f"{field_name} | t=%{{x:.3f}}s<br>Value=%{{y:.6g}}<extra>{label}</extra>",
            ))

        # Same axis titles and styling as the primary figures; previously this
        # plot had no axis titles at all.
        fig.update_layout(
            title=f"Other Components ({', '.join(other_fields)})",
            height=660, width=900,
            margin=dict(l=50, r=30, t=50, b=50),
            hovermode="x unified",
            legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1, title_text=""),
            paper_bgcolor="#FFFFFF",
            plot_bgcolor="#FFFFFF",
        )
        fig.update_xaxes(title_text="Time (s)", title_font=dict(size=22), tickfont=dict(size=16), showgrid=True, gridcolor="rgba(0,0,0,.06)")
        fig.update_yaxes(title_text="Field Value", title_font=dict(size=22), tickfont=dict(size=16), showgrid=True, gridcolor="rgba(0,0,0,.06)")
        if max_abs > 0:
            pad = 0.12 * max_abs
            fig.update_yaxes(range=[-max_abs - pad, max_abs + pad])
        return fig
    except Exception:
        return None
# Click handlers
def _qpu_displayed_primary_fig(base_fig):
    """Return the figure that the primary QPU plot currently displays.

    Mirrors ``refresh_qpu_plot_figures``: the figure filtered by
    ``state.qpu_plot_filter``/``state.qpu_plot_position_filter``, or
    *base_fig* when nothing matches. With no filter active, the filtered
    rebuild returns the cached base figure itself.
    """
    try:
        field_filter = (state.qpu_plot_filter or "All").strip()
    except Exception:
        field_filter = "All"
    try:
        position_filter = (state.qpu_plot_position_filter or "All positions").strip()
    except Exception:
        position_filter = "All positions"
    shown = rebuild_qpu_fig_filtered(field_filter, position_filter)
    return base_fig if shown is None else shown


def on_qpu_ts_click(evt):
    """Mark the clicked time with a vertical dotted line on the QPU time-series plot.

    The click's x value is snapped to the nearest cached time sample. The
    line is drawn on the cached base figure and also on the currently
    displayed (possibly filtered) figure, so an active component/position
    filter is kept. The snapped time is stored in
    ``state.qpu_ts_selected_time``. Errors are swallowed.
    """
    try:
        if not evt or "points" not in evt or not evt["points"]:
            return
        x = float(evt["points"][0].get("x"))
        times = qpu_ts_cache.get("times") or []
        fig = qpu_ts_cache.get("fig")
        if not times or fig is None:
            return
        idx = int(np.argmin(np.abs(np.asarray(times) - x)))
        sel_t = float(times[idx])
        marker_shapes = [dict(
            type="line", x0=sel_t, x1=sel_t, y0=0, y1=1,
            xref="x", yref="paper",
            line=dict(color="#5F259F", width=2, dash="dot")
        )]
        fig.update_layout(shapes=marker_shapes)
        qpu_ts_cache["fig"] = fig
        # Push the figure the user is looking at, not the unfiltered base.
        shown = _qpu_displayed_primary_fig(fig)
        if shown is not fig:
            shown.update_layout(shapes=marker_shapes)
        try:
            ctrl.qpu_ts_update(shown)
        except Exception:
            pass
        state.qpu_ts_selected_time = sel_t
    except Exception:
        pass


def on_qpu_ts_clear():
    """Remove the selected-time marker from the QPU time-series plot.

    The marker is cleared on the cached base figure, and the currently
    displayed (possibly filtered) figure is pushed without it.
    ``state.qpu_ts_selected_time`` is reset to None. Errors are swallowed.
    """
    try:
        fig = qpu_ts_cache.get("fig")
        if fig is None:
            return
        fig.update_layout(shapes=[])
        qpu_ts_cache["fig"] = fig
        shown = _qpu_displayed_primary_fig(fig)
        if shown is not fig:
            shown.update_layout(shapes=[])
        try:
            ctrl.qpu_ts_update(shown)
        except Exception:
            pass
        state.qpu_ts_selected_time = None
    except Exception:
        pass


# Register click handlers on controller
ctrl.on_qpu_ts_click = on_qpu_ts_click
ctrl.on_qpu_ts_clear = on_qpu_ts_clear
# Monitor config management
def qpu_add_monitor_config():
    """Append a fresh config from ``globals.new_monitor_cfg()`` to ``state.qpu_monitor_configs``.

    A new list object is always assigned; the previous list is not mutated.
    """
    from .globals import new_monitor_cfg

    state.qpu_monitor_configs = [*(state.qpu_monitor_configs or []), new_monitor_cfg()]
def qpu_remove_monitor_config(index):
    """Drop the monitor config at *index*; out-of-range indices remove nothing.

    A new list object is always reassigned to ``state.qpu_monitor_configs``.
    """
    remaining = list(state.qpu_monitor_configs or [])
    if 0 <= index < len(remaining):
        del remaining[index]
    state.qpu_monitor_configs = remaining
def qpu_set_monitor_field(index, value):
    """Set the ``"field"`` entry of the monitor config at *index*.

    The targeted config is replaced by an updated copy rather than mutated in
    place. The dict objects are shared with the list already held in
    ``state``, so mutating one would also change the "previous" value the UI
    framework may compare against. Out-of-range indices leave the configs
    unchanged; the list is still reassigned.
    """
    configs = list(state.qpu_monitor_configs or [])
    if 0 <= index < len(configs):
        configs[index] = {**configs[index], "field": value}
    state.qpu_monitor_configs = configs
def qpu_set_monitor_points(index, value):
    """Set the ``"points"`` entry of the monitor config at *index*.

    The targeted config is replaced by an updated copy rather than mutated in
    place. The dict objects are shared with the list already held in
    ``state``, so mutating one would also change the "previous" value the UI
    framework may compare against. Out-of-range indices leave the configs
    unchanged; the list is still reassigned.
    """
    configs = list(state.qpu_monitor_configs or [])
    if 0 <= index < len(configs):
        configs[index] = {**configs[index], "points": value}
    state.qpu_monitor_configs = configs
def qpu_set_plot_filter(value):
    """Apply a new component filter.

    Stores *value*, rebuilds the position choices for it, then redraws the plots.
    """
    setattr(state, "qpu_plot_filter", value)
    update_qpu_position_options(value)
    refresh_qpu_plot_figures()
def qpu_set_plot_position_filter(value):
    """Apply a new position filter by storing *value* and redrawing the plots."""
    setattr(state, "qpu_plot_position_filter", value)
    refresh_qpu_plot_figures()
def qpu_add_monitor_slot():
    """Increment ``state.qpu_monitor_count`` by one, capped at 4.

    A missing or non-numeric count is treated as 0.
    """
    try:
        current = int(state.qpu_monitor_count or 0)
    except Exception:
        current = 0
    if current >= 4:
        return
    state.qpu_monitor_count = current + 1
def qpu_remove_monitor_slot(slot_index):
    """Remove monitor slot *slot_index* by shifting later slots down by one.

    For each ``i`` in ``range(slot_index, count)``, the ``_{i+1}``-suffixed
    component, gridpoint, sample and sample-info state values are copied into
    the ``_{i}``-suffixed ones. ``state.qpu_monitor_count`` is then
    decremented, never below 0. Nothing happens when *slot_index* exceeds the
    current count.
    """
    # NOTE(review): the shift stops at the `_{count}` suffix and the vacated
    # last slot is not cleared. Also, slot 1 here would use the `_1` suffix,
    # whereas update_qpu_sample_slot uses unsuffixed names for slot 1. Confirm
    # how qpu_monitor_count maps to slot numbers.
    try:
        count = int(state.qpu_monitor_count or 0)
    except Exception:
        count = 0
    if slot_index > count:
        return
    per_slot = (
        ("qpu_field_components", "Ez"),
        ("qpu_monitor_gridpoints", ""),
        ("qpu_monitor_samples", ""),
        ("qpu_monitor_sample_info", ""),
    )
    for dst in range(slot_index, count):
        for prefix, fallback in per_slot:
            setattr(state, f"{prefix}_{dst}", getattr(state, f"{prefix}_{dst + 1}", fallback))
    state.qpu_monitor_count = max(0, count - 1)
# Register the monitor/filter callbacks on the controller under their own names.
# NOTE(review): qpu_add_monitor_config is exported but not registered here;
# confirm it is wired to ctrl elsewhere.
for _handler in (
    qpu_remove_monitor_config,
    qpu_set_monitor_field,
    qpu_set_monitor_points,
    qpu_set_plot_filter,
    qpu_set_plot_position_filter,
    qpu_add_monitor_slot,
    qpu_remove_monitor_slot,
):
    setattr(ctrl, _handler.__name__, _handler)
del _handler