# QureadAI / quread / heatmap.py
# Uploaded by hchevva: "Upload heatmap.py" (commit d0e98cb, verified)
# quread/heatmap.py
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Any
import numpy as np
import matplotlib.pyplot as plt
from .metrics import compute_metrics_from_csv, MetricWeights, MetricThresholds
# Plotly is optional: the interactive heatmap is only offered when it imports.
# Any import-time failure (not just ImportError) disables the Plotly backend.
try:
    import plotly.graph_objects as go
except Exception:
    _HAS_PLOTLY = False
else:
    _HAS_PLOTLY = True

logger = logging.getLogger(__name__)
@dataclass
class HeatmapConfig:
    """
    Layout settings for the chip-grid heatmap.

    Attributes:
        rows: Number of rows in the chip grid.
        cols: Number of columns in the chip grid.
        missing_value: Fill value for grid cells that receive no qubit value.
    """

    rows: int = 8  # grid height, in chip rows
    cols: int = 8  # grid width, in chip columns
    missing_value: float = 0.0  # placeholder for cells without a mapped qubit
def _default_qubit_coords(n_qubits: int, rows: int, cols: int) -> Dict[int, Tuple[int, int]]:
    """
    Map qubits q0, q1, ... onto the chip grid in row-major order.

    Qubit ``q`` is placed at ``(q // cols, q % cols)``. Only the first
    ``rows * cols`` qubits fit on the grid; the rest are left unmapped.
    A non-positive ``rows`` or ``cols`` yields an empty mapping instead of
    raising ``ZeroDivisionError`` (cols == 0) or emitting negative
    coordinates (cols < 0).

    Override with a user-uploaded mapping when the physical layout differs.
    """
    if rows <= 0 or cols <= 0:
        return {}
    # Qubits beyond the grid capacity have no cell to occupy.
    capacity = min(n_qubits, rows * cols)
    return {q: divmod(q, cols) for q in range(capacity)}
# Lower-cased user-facing metric names -> canonical metric key.
# Canonical keys also map to themselves so they pass through unchanged.
_METRIC_ALIASES = dict(
    # activity
    activity="activity_count",
    activity_count="activity_count",
    activity_norm="activity_norm",
    # error rates
    gate_error="gate_error",
    readout_error="readout_error",
    # coherence / decoherence
    coherence_health="coherence_health",
    decoherence_time="coherence_health",
    decoherence_times="coherence_health",
    decoherence="decoherence_risk",
    decoherence_risk="decoherence_risk",
    # fidelities
    fidelity="fidelity",
    state_fidelity="state_fidelity",
    process_fidelity="process_fidelity",
    # combined score
    composite="composite_risk",
    composite_risk="composite_risk",
)
# Canonical metric key -> plot title shown above the heatmap.
_METRIC_TITLES = dict(
    activity_count="Qubit activity count (from circuit CSV)",
    activity_norm="Normalized qubit activity (0-1)",
    gate_error="Gate error rate heatmap",
    readout_error="Readout error rate heatmap",
    coherence_health="Coherence health heatmap (long-lived qubits are deep blue)",
    decoherence_risk="Decoherence risk heatmap",
    fidelity="Fidelity heatmap",
    state_fidelity="State fidelity heatmap",
    process_fidelity="Process fidelity heatmap",
    composite_risk="Composite reliability risk heatmap",
)
# Canonical metric key -> colormap name. The same name is passed to
# matplotlib (cmap=...) and to Plotly (colorscale=...).
_METRIC_CMAP = dict(
    activity_count="viridis",
    activity_norm="viridis",
    gate_error="Reds",
    readout_error="Reds",
    coherence_health="Blues",
    decoherence_risk="inferno",
    fidelity="Greens",
    state_fidelity="Greens",
    process_fidelity="Greens",
    composite_risk="hot",
)
def _resolve_metric(metric: str) -> str:
    """
    Normalize a user-facing metric name to its canonical metric key.

    The name is stripped and lower-cased, then looked up in ``_METRIC_ALIASES``.
    An empty or ``None`` name selects ``"activity_count"``. An unrecognized
    name also falls back to ``"activity_count"``. It now logs a warning, so a
    typo or a UI label that matches no alias no longer silently plots the
    wrong metric.
    """
    requested = (metric or "activity_count").strip().lower()
    resolved = _METRIC_ALIASES.get(requested)
    if resolved is None:
        logger.warning(
            "Unknown heatmap metric %r; falling back to 'activity_count'.", metric
        )
        return "activity_count"
    return resolved
def plotly_available() -> bool:
    """Report whether the optional Plotly backend was imported successfully."""
    return True if _HAS_PLOTLY else False
def _build_metric_grid(
    csv_text: str,
    n_qubits: int,
    metric: str,
    cfg: HeatmapConfig,
    calibration_json: str,
    state_vector: Optional[np.ndarray],
    weights: Optional[MetricWeights],
    thresholds: Optional[MetricThresholds],
    qubit_coords: Optional[Dict[int, Tuple[int, int]]],
) -> Tuple[str, np.ndarray, np.ndarray, Dict[str, Any], Dict[int, Tuple[int, int]]]:
    """
    Compute per-qubit metrics and scatter the selected one onto the chip grid.

    Returns ``(metric_key, grid, values, meta, coords)``:
      * metric_key: canonical metric name resolved from ``metric``.
      * grid: ``(cfg.rows, cfg.cols)`` float array. A cell holds the value of
        the qubit mapped to it when that qubit is in ``[0, n_qubits)``;
        otherwise it holds ``cfg.missing_value``.
      * values: ``metrics[metric_key]`` from ``compute_metrics_from_csv``,
        indexed by qubit number.
      * meta: the metadata dict returned by ``compute_metrics_from_csv``.
      * coords: the qubit -> (row, col) mapping used. This is ``qubit_coords``
        when it is truthy, otherwise the default row-major layout.
    """
    key = _resolve_metric(metric)
    layout = qubit_coords if qubit_coords else _default_qubit_coords(n_qubits, cfg.rows, cfg.cols)
    n_rows, n_cols = cfg.rows, cfg.cols
    grid = np.full((n_rows, n_cols), cfg.missing_value, dtype=float)

    metrics, meta = compute_metrics_from_csv(
        csv_text,
        int(n_qubits),
        calibration_json=calibration_json,
        state_vector=state_vector,
        weights=weights,
        thresholds=thresholds,
    )
    per_qubit = metrics[key]

    for qubit, cell in layout.items():
        row, col = cell
        # Skip mapping entries for qubits that do not exist in this circuit.
        if not (0 <= qubit < n_qubits):
            continue
        # Skip coordinates that fall outside the configured chip grid.
        if not (0 <= row < n_rows and 0 <= col < n_cols):
            continue
        grid[row, col] = per_qubit[qubit]

    return key, grid, per_qubit, meta, layout
def make_metric_heatmap(
    csv_text: str,
    n_qubits: int,
    metric: str = "activity_count",
    cfg: Optional[HeatmapConfig] = None,
    calibration_json: str = "",
    state_vector: Optional[np.ndarray] = None,
    weights: Optional[MetricWeights] = None,
    thresholds: Optional[MetricThresholds] = None,
    highlight_threshold: Optional[float] = None,
    qubit_coords: Optional[Dict[int, Tuple[int, int]]] = None,
) -> plt.Figure:
    """
    Render a static matplotlib heatmap of one per-qubit metric on the chip grid.

    The metric is computed from the circuit CSV (plus optional calibration
    JSON and state vector) and drawn with ``imshow``. Each mapped cell is
    labelled with its qubit id.

    Args:
        csv_text: Circuit CSV text.
        n_qubits: Number of qubits; only qubits ``0 <= q < n_qubits`` get values.
        metric: Metric name or alias, resolved to a canonical key.
        cfg: Grid layout; defaults to ``HeatmapConfig()``.
        calibration_json, state_vector, weights, thresholds: Forwarded
            unchanged to the metric computation.
        highlight_threshold: If given, clipped to ``[0, 1e9]``. Every mapped
            qubit whose value is ``>=`` the threshold gets an amber outline,
            and the count is printed below the axes.
        qubit_coords: Optional ``{qubit: (row, col)}`` layout; the default
            layout is row-major.

    Returns:
        The matplotlib ``Figure``. Skipped CSV rows and any calibration note
        are shown as small captions above the axes.
        NOTE(review): the figure is created with ``plt.subplots`` and never
        closed here, so long-running callers should ``plt.close(fig)`` when done.
    """
    cfg = cfg or HeatmapConfig()
    metric_key, grid, values, meta, coords = _build_metric_grid(
        csv_text=csv_text,
        n_qubits=int(n_qubits),
        metric=str(metric),
        cfg=cfg,
        calibration_json=calibration_json,
        state_vector=state_vector,
        weights=weights,
        thresholds=thresholds,
        qubit_coords=qubit_coords,
    )

    def on_chip(row: int, col: int) -> bool:
        # True when (row, col) lies inside the configured chip grid.
        return 0 <= row < cfg.rows and 0 <= col < cfg.cols

    fig, ax = plt.subplots(figsize=(6, 5))
    image = ax.imshow(grid, interpolation="nearest", cmap=_METRIC_CMAP[metric_key])
    ax.set_title(_METRIC_TITLES[metric_key])
    ax.set_xlabel("Chip column")
    ax.set_ylabel("Chip row")
    fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)

    # Status captions sit just above the axes: data problems on the left,
    # calibration information on the right.
    caption_style = {"transform": ax.transAxes, "fontsize": 8, "va": "bottom"}
    skipped = int(meta.get("skipped_rows", 0))
    if skipped:
        logger.warning("Skipped %d malformed CSV rows while building heatmap.", skipped)
        ax.text(
            0.01, 1.02, f"Skipped {skipped} malformed CSV row(s)",
            color="#b91c1c", ha="left", **caption_style,
        )
    calibration_note = str(meta.get("calibration_note", "") or "").strip()
    if calibration_note:
        ax.text(0.99, 1.02, calibration_note, color="#92400e", ha="right", **caption_style)

    # Label every on-chip cell with its qubit id.
    for qubit, (row, col) in coords.items():
        if on_chip(row, col):
            ax.text(col, row, f"q{qubit}", ha="center", va="center", fontsize=9)

    if highlight_threshold is not None:
        thr = float(np.clip(float(highlight_threshold), 0.0, 1e9))
        highlighted = 0
        for qubit, (row, col) in coords.items():
            if not (0 <= qubit < n_qubits and on_chip(row, col)):
                continue
            # Written as not(>=) so NaN values are never highlighted.
            if not (float(values[qubit]) >= thr):
                continue
            highlighted += 1
            outline = plt.Rectangle(
                (col - 0.5, row - 0.5), 1.0, 1.0,
                fill=False, linewidth=2.0, edgecolor="#f59e0b",
            )
            ax.add_patch(outline)
        ax.text(
            0.5, -0.12, f"Highlighted qubits (value >= {thr:.4g}): {highlighted}",
            transform=ax.transAxes, fontsize=8, color="#92400e", ha="center", va="top",
        )

    ax.set_xticks(range(cfg.cols))
    ax.set_yticks(range(cfg.rows))
    fig.tight_layout()
    return fig
def _plotly_color_range(grid: np.ndarray) -> Tuple[float, float]:
    """
    Return the ``(zmin, zmax)`` colour-scale bounds for the Plotly heatmap.

    Only finite cells contribute, so NaN or inf cells no longer turn the
    bounds into NaN or inf. Such cells arise, for example, from a
    ``HeatmapConfig.missing_value`` of ``float("nan")`` used to render
    unmapped cells as gaps. An empty grid (rows or cols == 0) or an
    all-non-finite grid falls back to a lower bound of 0.0 instead of
    raising. A degenerate range is widened by 1e-9 so Plotly always
    receives ``zmax > zmin``.
    """
    cells = np.asarray(grid, dtype=float)
    finite = cells[np.isfinite(cells)]
    if finite.size == 0:
        zmin = zmax = 0.0
    else:
        zmin, zmax = float(finite.min()), float(finite.max())
    if zmax <= zmin:
        zmax = zmin + 1e-9
    return zmin, zmax


def make_metric_heatmap_plotly(
    csv_text: str,
    n_qubits: int,
    metric: str = "activity_count",
    cfg: Optional[HeatmapConfig] = None,
    calibration_json: str = "",
    state_vector: Optional[np.ndarray] = None,
    weights: Optional[MetricWeights] = None,
    thresholds: Optional[MetricThresholds] = None,
    highlight_threshold: Optional[float] = None,
    qubit_coords: Optional[Dict[int, Tuple[int, int]]] = None,
) -> Any:
    """
    Build an interactive Plotly heatmap of one per-qubit metric for zoom/pan.

    Takes the same arguments as ``make_metric_heatmap``:
      * Each mapped cell is labelled with its qubit id.
      * ``highlight_threshold`` is clipped to ``[0, 1e9]``; qubits with
        value ``>=`` it get an amber outline.
      * Skipped CSV rows, the calibration note and the highlight count are
        joined into one caption below the plot.

    The colour scale spans only the finite grid cells (see
    ``_plotly_color_range``).

    Returns:
        A ``plotly.graph_objects.Figure``.

    Raises:
        RuntimeError: if Plotly could not be imported.
    """
    if not _HAS_PLOTLY:
        raise RuntimeError("Plotly is not available in this environment.")
    cfg = cfg or HeatmapConfig()
    metric_key, grid, values, meta, coords = _build_metric_grid(
        csv_text,
        int(n_qubits),
        str(metric),
        cfg,
        calibration_json,
        state_vector,
        weights,
        thresholds,
        qubit_coords,
    )
    colorscale = _METRIC_CMAP[metric_key]
    # Previously np.min/np.max: NaN cells made both bounds NaN and an empty
    # grid raised ValueError.
    zmin, zmax = _plotly_color_range(grid)
    fig = go.Figure(
        data=go.Heatmap(
            z=grid,
            colorscale=colorscale,
            zmin=zmin,
            zmax=zmax,
            colorbar={"title": metric_key},
            hoverongaps=False,
        )
    )
    # Label every on-chip cell with its qubit id (data coordinates).
    for q, (rr, cc) in coords.items():
        if 0 <= rr < cfg.rows and 0 <= cc < cfg.cols:
            fig.add_annotation(
                x=cc,
                y=rr,
                text=f"q{q}",
                showarrow=False,
                font={"size": 10, "color": "#0f172a"},
            )
    notes = []
    skipped = int(meta.get("skipped_rows", 0))
    if skipped:
        logger.warning("Skipped %d malformed CSV rows while building heatmap.", skipped)
        notes.append(f"Skipped malformed CSV rows: {skipped}")
    calibration_note = str(meta.get("calibration_note", "") or "").strip()
    if calibration_note:
        notes.append(calibration_note)
    if highlight_threshold is not None:
        thr = float(np.clip(float(highlight_threshold), 0.0, 1e9))
        highlighted = 0
        for q, (rr, cc) in coords.items():
            if 0 <= q < n_qubits and 0 <= rr < cfg.rows and 0 <= cc < cfg.cols:
                if float(values[q]) >= thr:
                    highlighted += 1
                    fig.add_shape(
                        type="rect",
                        x0=cc - 0.5,
                        x1=cc + 0.5,
                        y0=rr - 0.5,
                        y1=rr + 0.5,
                        line={"color": "#f59e0b", "width": 2},
                        fillcolor="rgba(0,0,0,0)",
                    )
        notes.append(f"Highlighted qubits (value >= {thr:.4g}): {highlighted}")
    fig.update_layout(
        title=_METRIC_TITLES[metric_key],
        xaxis={"title": "Chip column", "tickmode": "array", "tickvals": list(range(cfg.cols))},
        yaxis={
            "title": "Chip row",
            "tickmode": "array",
            "tickvals": list(range(cfg.rows)),
            # Row 0 at the top, matching the matplotlib imshow orientation.
            "autorange": "reversed",
            # Square cells.
            "scaleanchor": "x",
            "scaleratio": 1,
        },
        margin={"l": 50, "r": 30, "t": 70, "b": 50},
        dragmode="pan",
    )
    if notes:
        fig.add_annotation(
            text=" | ".join(notes),
            xref="paper",
            yref="paper",
            x=0.5,
            y=-0.17,
            showarrow=False,
            font={"size": 11, "color": "#92400e"},
        )
    return fig
def make_activity_heatmap(
    csv_text: str,
    n_qubits: int,
    cfg: Optional[HeatmapConfig] = None,
    qubit_coords: Optional[Dict[int, Tuple[int, int]]] = None,
) -> plt.Figure:
    """
    Convenience wrapper: plot raw per-qubit activity counts from a circuit CSV.

    Calls ``make_metric_heatmap`` with ``metric="activity_count"``, an empty
    calibration JSON, ``weights``/``thresholds`` of ``None`` and no
    highlighting. ``state_vector`` is left at its default.
    """
    fixed_options: Dict[str, Any] = {
        "metric": "activity_count",
        "cfg": cfg,
        "calibration_json": "",
        "weights": None,
        "thresholds": None,
        "highlight_threshold": None,
        "qubit_coords": qubit_coords,
    }
    return make_metric_heatmap(csv_text=csv_text, n_qubits=n_qubits, **fixed_options)