| |
| from __future__ import annotations |
|
|
| import logging |
| from dataclasses import dataclass |
| from typing import Dict, Optional, Tuple, Any |
|
|
| import numpy as np |
| import matplotlib.pyplot as plt |
|
|
| from .metrics import compute_metrics_from_csv, MetricWeights, MetricThresholds |
|
|
# Plotly is optional: if it cannot be imported, the interactive heatmap path is
# disabled and ``go`` is left undefined.
try:
    import plotly.graph_objects as go
except Exception:
    # Deliberately broad: any import-time failure just disables Plotly support.
    _HAS_PLOTLY = False
else:
    _HAS_PLOTLY = True


logger = logging.getLogger(__name__)
|
|
|
|
@dataclass
class HeatmapConfig:
    """Chip-grid layout used when rendering per-qubit heatmaps.

    The heatmap is a ``rows`` x ``cols`` grid; any cell that no in-range qubit
    is mapped onto keeps ``missing_value``.
    """

    # Number of grid rows.
    rows: int = 8
    # Number of grid columns.
    cols: int = 8
    # Fill value for cells without a mapped qubit.
    missing_value: float = 0.0
|
|
|
|
def _default_qubit_coords(n_qubits: int, rows: int, cols: int) -> Dict[int, Tuple[int, int]]:
    """
    Map qubit indices row-major onto a ``rows`` x ``cols`` chip grid.

    Qubit ``q`` is placed at ``(q // cols, q % cols)``. Qubits that do not fit
    on the grid (``q >= rows * cols``) are left unmapped. Callers may supply
    their own mapping instead of this default.

    Returns an empty mapping when ``n_qubits``, ``rows`` or ``cols`` is not
    positive.
    """
    # Guard first: ``cols == 0`` would otherwise raise ZeroDivisionError, and a
    # negative ``cols`` would yield negative/nonsensical coordinates.
    if n_qubits <= 0 or rows <= 0 or cols <= 0:
        return {}
    # Row q // cols is < rows exactly when q < rows * cols.
    placed = min(n_qubits, rows * cols)
    return {q: divmod(q, cols) for q in range(placed)}
|
|
|
|
# Lower-case metric name or alias -> canonical metric key. Canonical keys map to
# themselves; the remaining entries are shorthand aliases.
_METRIC_ALIASES = dict(
    # Activity.
    activity="activity_count",
    activity_count="activity_count",
    activity_norm="activity_norm",
    # Error rates.
    gate_error="gate_error",
    readout_error="readout_error",
    # Coherence: "decoherence_time(s)" resolves to coherence_health, while
    # plain "decoherence" resolves to decoherence_risk.
    coherence_health="coherence_health",
    decoherence_time="coherence_health",
    decoherence_times="coherence_health",
    decoherence="decoherence_risk",
    decoherence_risk="decoherence_risk",
    # Fidelities.
    fidelity="fidelity",
    state_fidelity="state_fidelity",
    process_fidelity="process_fidelity",
    # Composite score.
    composite="composite_risk",
    composite_risk="composite_risk",
)
|
|
# Plot title for each canonical metric key.
_METRIC_TITLES = dict(
    activity_count="Qubit activity count (from circuit CSV)",
    activity_norm="Normalized qubit activity (0-1)",
    gate_error="Gate error rate heatmap",
    readout_error="Readout error rate heatmap",
    coherence_health="Coherence health heatmap (long-lived qubits are deep blue)",
    decoherence_risk="Decoherence risk heatmap",
    fidelity="Fidelity heatmap",
    state_fidelity="State fidelity heatmap",
    process_fidelity="Process fidelity heatmap",
    composite_risk="Composite reliability risk heatmap",
)
|
|
# Colour map name for each canonical metric key; passed as ``cmap`` to
# matplotlib and as ``colorscale`` to Plotly.
_METRIC_CMAP = dict(
    activity_count="viridis",
    activity_norm="viridis",
    gate_error="Reds",
    readout_error="Reds",
    coherence_health="Blues",
    decoherence_risk="inferno",
    fidelity="Greens",
    state_fidelity="Greens",
    process_fidelity="Greens",
    composite_risk="hot",
)
|
|
|
|
def _resolve_metric(metric: str) -> str:
    """
    Normalize a user-supplied metric name to a canonical metric key.

    Matching ignores case and surrounding whitespace, and treats ``-`` and
    spaces like ``_`` (so ``"Gate-Error"`` resolves to ``"gate_error"``).
    An empty or ``None`` name selects ``"activity_count"``. Unknown names also
    fall back to ``"activity_count"``, but a warning is logged so a typo does
    not silently produce the wrong plot.
    """
    requested = (metric or "activity_count").strip().lower()
    # No alias key contains '-' or ' ', so this only rescues inputs that would
    # otherwise have been unknown.
    normalized = requested.replace("-", "_").replace(" ", "_")
    resolved = _METRIC_ALIASES.get(normalized)
    if resolved is None:
        logger.warning(
            "Unknown heatmap metric %r; falling back to 'activity_count'.", metric
        )
        return "activity_count"
    return resolved
|
|
|
|
def plotly_available() -> bool:
    """Report whether Plotly was importable when this module was loaded."""
    return True if _HAS_PLOTLY else False
|
|
|
|
def _build_metric_grid(
    csv_text: str,
    n_qubits: int,
    metric: str,
    cfg: HeatmapConfig,
    calibration_json: str,
    state_vector: Optional[np.ndarray],
    weights: Optional[MetricWeights],
    thresholds: Optional[MetricThresholds],
    qubit_coords: Optional[Dict[int, Tuple[int, int]]],
) -> Tuple[str, np.ndarray, np.ndarray, Dict[str, Any], Dict[int, Tuple[int, int]]]:
    """
    Compute per-qubit metrics and scatter the selected one onto the chip grid.

    Returns ``(metric_key, grid, values, meta, coords)``:

    * ``metric_key`` -- canonical metric name resolved from ``metric``;
    * ``grid`` -- ``cfg.rows`` x ``cfg.cols`` float array; cells without an
      in-range qubit hold ``cfg.missing_value``;
    * ``values`` -- the per-qubit values for ``metric_key`` exactly as returned
      by ``compute_metrics_from_csv`` (indexed by qubit number);
    * ``meta`` -- the metadata returned by ``compute_metrics_from_csv``;
    * ``coords`` -- the qubit -> (row, col) mapping that was used.

    Raises:
        KeyError: if ``compute_metrics_from_csv`` did not produce ``metric_key``.
    """
    metric_key = _resolve_metric(metric)
    # A missing (or empty) mapping falls back to the row-major default layout.
    coords = qubit_coords or _default_qubit_coords(n_qubits, cfg.rows, cfg.cols)
    grid = np.full((cfg.rows, cfg.cols), cfg.missing_value, dtype=float)
    metrics, meta = compute_metrics_from_csv(
        csv_text,
        int(n_qubits),
        calibration_json=calibration_json,
        state_vector=state_vector,
        weights=weights,
        thresholds=thresholds,
    )
    # NOTE(review): assumes ``metrics`` is a mapping keyed by canonical metric
    # name; confirm against compute_metrics_from_csv.
    if metric_key not in metrics:
        available = ", ".join(sorted(str(k) for k in metrics)) or "none"
        raise KeyError(
            f"Metric {metric_key!r} was not computed; available metrics: {available}"
        )
    values = metrics[metric_key]
    for q, (rr, cc) in coords.items():
        # Skip qubits outside [0, n_qubits) and coordinates off the grid.
        if 0 <= q < n_qubits and 0 <= rr < cfg.rows and 0 <= cc < cfg.cols:
            grid[rr, cc] = values[q]
    return metric_key, grid, values, meta, coords
|
|
|
|
def make_metric_heatmap(
    csv_text: str,
    n_qubits: int,
    metric: str = "activity_count",
    cfg: Optional[HeatmapConfig] = None,
    calibration_json: str = "",
    state_vector: Optional[np.ndarray] = None,
    weights: Optional[MetricWeights] = None,
    thresholds: Optional[MetricThresholds] = None,
    highlight_threshold: Optional[float] = None,
    qubit_coords: Optional[Dict[int, Tuple[int, int]]] = None,
) -> plt.Figure:
    """
    Build a matplotlib heatmap of one per-qubit metric laid out on the chip grid.

    Metrics are computed from the circuit CSV (and optional calibration JSON /
    state vector) by ``compute_metrics_from_csv``.

    Args:
        csv_text: Circuit CSV contents.
        n_qubits: Number of qubits; only indices ``0 <= q < n_qubits`` get values.
        metric: Metric name or alias; unknown names fall back to ``"activity_count"``.
        cfg: Grid layout; defaults to ``HeatmapConfig()``.
        calibration_json, state_vector, weights, thresholds: Forwarded to
            ``compute_metrics_from_csv``.
        highlight_threshold: If given, every qubit whose value is >= the
            threshold (clipped to ``[0, 1e9]``) gets an outlined cell, and the
            count is written below the axes.
        qubit_coords: Optional qubit -> (row, col) mapping; defaults to row-major.

    Returns:
        A new matplotlib Figure; the caller is responsible for closing it.
    """
    cfg = cfg or HeatmapConfig()
    n = int(n_qubits)
    metric_key, grid, values, meta, coords = _build_metric_grid(
        csv_text,
        n,
        str(metric),
        cfg,
        calibration_json,
        state_vector,
        weights,
        thresholds,
        qubit_coords,
    )

    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(grid, interpolation="nearest", cmap=_METRIC_CMAP[metric_key])
    ax.set_title(_METRIC_TITLES[metric_key])
    ax.set_xlabel("Chip column")
    ax.set_ylabel("Chip row")
    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    # Name the colour bar, as the Plotly variant does via its colorbar title.
    cbar.set_label(metric_key)

    skipped = int(meta.get("skipped_rows", 0))
    if skipped:
        logger.warning("Skipped %d malformed CSV rows while building heatmap.", skipped)
        ax.text(
            0.01,
            1.02,
            f"Skipped {skipped} malformed CSV row(s)",
            transform=ax.transAxes,
            fontsize=8,
            color="#b91c1c",
            ha="left",
            va="bottom",
        )

    calibration_note = str(meta.get("calibration_note", "") or "").strip()
    if calibration_note:
        ax.text(
            0.99,
            1.02,
            calibration_note,
            transform=ax.transAxes,
            fontsize=8,
            color="#92400e",
            ha="right",
            va="bottom",
        )

    def on_grid(rr: int, cc: int) -> bool:
        return 0 <= rr < cfg.rows and 0 <= cc < cfg.cols

    def label_color(value: float) -> str:
        # Black labels are unreadable on the dark end of colormaps such as
        # viridis/inferno, so pick black or white from the cell's luminance.
        if not np.isfinite(value):
            return "black"
        red, green, blue, _alpha = im.cmap(im.norm(value))
        luminance = 0.299 * red + 0.587 * green + 0.114 * blue
        return "white" if luminance < 0.5 else "black"

    # Label every mapped cell that lies on the grid with its qubit id.
    for q, (rr, cc) in coords.items():
        if on_grid(rr, cc):
            ax.text(
                cc,
                rr,
                f"q{q}",
                ha="center",
                va="center",
                fontsize=9,
                color=label_color(grid[int(rr), int(cc)]),
            )

    if highlight_threshold is not None:
        # Negative thresholds are clipped to 0 (and very large ones to 1e9).
        thr = float(np.clip(float(highlight_threshold), 0.0, 1e9))
        highlighted = 0
        for q, (rr, cc) in coords.items():
            if 0 <= q < n and on_grid(rr, cc) and float(values[q]) >= thr:
                highlighted += 1
                ax.add_patch(
                    plt.Rectangle(
                        (cc - 0.5, rr - 0.5),
                        1.0,
                        1.0,
                        fill=False,
                        linewidth=2.0,
                        edgecolor="#f59e0b",
                    )
                )
        ax.text(
            0.5,
            -0.12,
            f"Highlighted qubits (value >= {thr:.4g}): {highlighted}",
            transform=ax.transAxes,
            fontsize=8,
            color="#92400e",
            ha="center",
            va="top",
        )

    ax.set_xticks(range(cfg.cols))
    ax.set_yticks(range(cfg.rows))
    fig.tight_layout()
    return fig
|
|
|
|
def make_metric_heatmap_plotly(
    csv_text: str,
    n_qubits: int,
    metric: str = "activity_count",
    cfg: Optional[HeatmapConfig] = None,
    calibration_json: str = "",
    state_vector: Optional[np.ndarray] = None,
    weights: Optional[MetricWeights] = None,
    thresholds: Optional[MetricThresholds] = None,
    highlight_threshold: Optional[float] = None,
    qubit_coords: Optional[Dict[int, Tuple[int, int]]] = None,
) -> Any:
    """
    Build an interactive Plotly heatmap (zoom/pan) of one per-qubit metric.

    Takes the same arguments as ``make_metric_heatmap``. Skipped-row counts,
    the calibration note and the highlight count are combined into a single
    note line shown below the plot.

    Returns:
        A ``plotly.graph_objects.Figure``.

    Raises:
        RuntimeError: if Plotly could not be imported.
    """
    if not _HAS_PLOTLY:
        raise RuntimeError("Plotly is not available in this environment.")

    cfg = cfg or HeatmapConfig()
    n = int(n_qubits)
    metric_key, grid, values, meta, coords = _build_metric_grid(
        csv_text,
        n,
        str(metric),
        cfg,
        calibration_json,
        state_vector,
        weights,
        thresholds,
        qubit_coords,
    )

    # Derive the colour range from finite cells only: a NaN/inf produced by a
    # metric would otherwise turn zmin/zmax into NaN/inf.
    finite = grid[np.isfinite(grid)]
    if finite.size:
        zmin, zmax = float(finite.min()), float(finite.max())
    else:
        zmin, zmax = 0.0, 1.0
    if zmax <= zmin:
        # Constant grid: avoid a zero-width colour range.
        zmax = zmin + 1e-9

    fig = go.Figure(
        data=go.Heatmap(
            z=grid,
            colorscale=_METRIC_CMAP[metric_key],
            zmin=zmin,
            zmax=zmax,
            colorbar={"title": metric_key},
            hoverongaps=False,
        )
    )

    def on_grid(rr: int, cc: int) -> bool:
        return 0 <= rr < cfg.rows and 0 <= cc < cfg.cols

    # Label every mapped cell that lies on the grid with its qubit id.
    for q, (rr, cc) in coords.items():
        if on_grid(rr, cc):
            fig.add_annotation(
                x=cc,
                y=rr,
                text=f"q{q}",
                showarrow=False,
                font={"size": 10, "color": "#0f172a"},
            )

    notes = []
    skipped = int(meta.get("skipped_rows", 0))
    if skipped:
        logger.warning("Skipped %d malformed CSV rows while building heatmap.", skipped)
        notes.append(f"Skipped malformed CSV rows: {skipped}")

    calibration_note = str(meta.get("calibration_note", "") or "").strip()
    if calibration_note:
        notes.append(calibration_note)

    if highlight_threshold is not None:
        # Negative thresholds are clipped to 0 (and very large ones to 1e9).
        thr = float(np.clip(float(highlight_threshold), 0.0, 1e9))
        highlighted = 0
        for q, (rr, cc) in coords.items():
            if 0 <= q < n and on_grid(rr, cc) and float(values[q]) >= thr:
                highlighted += 1
                fig.add_shape(
                    type="rect",
                    x0=cc - 0.5,
                    x1=cc + 0.5,
                    y0=rr - 0.5,
                    y1=rr + 0.5,
                    line={"color": "#f59e0b", "width": 2},
                    fillcolor="rgba(0,0,0,0)",
                )
        notes.append(f"Highlighted qubits (value >= {thr:.4g}): {highlighted}")

    fig.update_layout(
        title=_METRIC_TITLES[metric_key],
        xaxis={"title": "Chip column", "tickmode": "array", "tickvals": list(range(cfg.cols))},
        yaxis={
            "title": "Chip row",
            "tickmode": "array",
            "tickvals": list(range(cfg.rows)),
            # Row 0 at the top, matching matplotlib's imshow orientation.
            "autorange": "reversed",
            "scaleanchor": "x",
            "scaleratio": 1,
        },
        # Reserve extra room below the x-axis title when a note line is shown.
        margin={"l": 50, "r": 30, "t": 70, "b": 90 if notes else 50},
        dragmode="pan",
    )

    if notes:
        fig.add_annotation(
            text=" | ".join(notes),
            xref="paper",
            yref="paper",
            x=0.5,
            y=0.0,
            xanchor="center",
            yanchor="top",
            # Pixel offset below the plot area so the note clears the tick
            # labels and axis title instead of overlapping them.
            yshift=-50,
            showarrow=False,
            font={"size": 11, "color": "#92400e"},
        )

    return fig
|
|
|
|
def make_activity_heatmap(
    csv_text: str,
    n_qubits: int,
    cfg: Optional[HeatmapConfig] = None,
    qubit_coords: Optional[Dict[int, Tuple[int, int]]] = None,
) -> plt.Figure:
    """
    Convenience wrapper: matplotlib heatmap of per-qubit activity counts.

    Delegates to ``make_metric_heatmap`` with ``metric="activity_count"``, an
    empty calibration JSON, no weights/thresholds and no highlighting.
    """
    options = {
        "metric": "activity_count",
        "cfg": cfg,
        "calibration_json": "",
        "weights": None,
        "thresholds": None,
        "highlight_threshold": None,
        "qubit_coords": qubit_coords,
    }
    return make_metric_heatmap(csv_text, n_qubits, **options)
|
|