hchevva commited on
Commit
d0e98cb
·
verified ·
1 Parent(s): f10100d

Upload heatmap.py

Browse files
Files changed (1) hide show
  1. quread/heatmap.py +165 -16
quread/heatmap.py CHANGED
@@ -3,13 +3,21 @@ from __future__ import annotations
3
 
4
  import logging
5
  from dataclasses import dataclass
6
- from typing import Dict, Optional, Tuple
7
 
8
  import numpy as np
9
  import matplotlib.pyplot as plt
10
 
11
  from .metrics import compute_metrics_from_csv, MetricWeights, MetricThresholds
12
 
 
 
 
 
 
 
 
 
13
  logger = logging.getLogger(__name__)
14
 
15
 
@@ -85,6 +93,39 @@ def _resolve_metric(metric: str) -> str:
85
  return _METRIC_ALIASES.get(m, "activity_count")
86
 
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  def make_metric_heatmap(
89
  csv_text: str,
90
  n_qubits: int,
@@ -101,24 +142,17 @@ def make_metric_heatmap(
101
  Builds a heatmap for a selected metric computed from circuit CSV and optional calibration JSON.
102
  """
103
  cfg = cfg or HeatmapConfig()
104
- coords = qubit_coords or _default_qubit_coords(n_qubits, cfg.rows, cfg.cols)
105
- metric_key = _resolve_metric(metric)
106
-
107
- grid = np.full((cfg.rows, cfg.cols), cfg.missing_value, dtype=float)
108
- metrics, meta = compute_metrics_from_csv(
109
  csv_text,
110
  int(n_qubits),
111
- calibration_json=calibration_json,
112
- state_vector=state_vector,
113
- weights=weights,
114
- thresholds=thresholds,
 
 
 
115
  )
116
- values = metrics[metric_key]
117
-
118
- # place into chip grid
119
- for q, (rr, cc) in coords.items():
120
- if 0 <= q < n_qubits and 0 <= rr < cfg.rows and 0 <= cc < cfg.cols:
121
- grid[rr, cc] = values[q]
122
 
123
  # plot
124
  fig, ax = plt.subplots(figsize=(6, 5))
@@ -194,6 +228,121 @@ def make_metric_heatmap(
194
  return fig
195
 
196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  def make_activity_heatmap(
198
  csv_text: str,
199
  n_qubits: int,
 
3
 
4
  import logging
5
  from dataclasses import dataclass
6
+ from typing import Dict, Optional, Tuple, Any
7
 
8
  import numpy as np
9
  import matplotlib.pyplot as plt
10
 
11
  from .metrics import compute_metrics_from_csv, MetricWeights, MetricThresholds
12
 
13
+ _HAS_PLOTLY = False
14
+ try:
15
+ import plotly.graph_objects as go
16
+
17
+ _HAS_PLOTLY = True
18
+ except Exception:
19
+ _HAS_PLOTLY = False
20
+
21
  logger = logging.getLogger(__name__)
22
 
23
 
 
93
  return _METRIC_ALIASES.get(m, "activity_count")
94
 
95
 
96
+ def plotly_available() -> bool:
97
+ return bool(_HAS_PLOTLY)
98
+
99
+
100
+ def _build_metric_grid(
101
+ csv_text: str,
102
+ n_qubits: int,
103
+ metric: str,
104
+ cfg: HeatmapConfig,
105
+ calibration_json: str,
106
+ state_vector: Optional[np.ndarray],
107
+ weights: Optional[MetricWeights],
108
+ thresholds: Optional[MetricThresholds],
109
+ qubit_coords: Optional[Dict[int, Tuple[int, int]]],
110
+ ) -> Tuple[str, np.ndarray, np.ndarray, Dict[str, Any], Dict[int, Tuple[int, int]]]:
111
+ metric_key = _resolve_metric(metric)
112
+ coords = qubit_coords or _default_qubit_coords(n_qubits, cfg.rows, cfg.cols)
113
+ grid = np.full((cfg.rows, cfg.cols), cfg.missing_value, dtype=float)
114
+ metrics, meta = compute_metrics_from_csv(
115
+ csv_text,
116
+ int(n_qubits),
117
+ calibration_json=calibration_json,
118
+ state_vector=state_vector,
119
+ weights=weights,
120
+ thresholds=thresholds,
121
+ )
122
+ values = metrics[metric_key]
123
+ for q, (rr, cc) in coords.items():
124
+ if 0 <= q < n_qubits and 0 <= rr < cfg.rows and 0 <= cc < cfg.cols:
125
+ grid[rr, cc] = values[q]
126
+ return metric_key, grid, values, meta, coords
127
+
128
+
129
  def make_metric_heatmap(
130
  csv_text: str,
131
  n_qubits: int,
 
142
  Builds a heatmap for a selected metric computed from circuit CSV and optional calibration JSON.
143
  """
144
  cfg = cfg or HeatmapConfig()
145
+ metric_key, grid, values, meta, coords = _build_metric_grid(
 
 
 
 
146
  csv_text,
147
  int(n_qubits),
148
+ str(metric),
149
+ cfg,
150
+ calibration_json,
151
+ state_vector,
152
+ weights,
153
+ thresholds,
154
+ qubit_coords,
155
  )
 
 
 
 
 
 
156
 
157
  # plot
158
  fig, ax = plt.subplots(figsize=(6, 5))
 
228
  return fig
229
 
230
 
231
+ def make_metric_heatmap_plotly(
232
+ csv_text: str,
233
+ n_qubits: int,
234
+ metric: str = "activity_count",
235
+ cfg: Optional[HeatmapConfig] = None,
236
+ calibration_json: str = "",
237
+ state_vector: Optional[np.ndarray] = None,
238
+ weights: Optional[MetricWeights] = None,
239
+ thresholds: Optional[MetricThresholds] = None,
240
+ highlight_threshold: Optional[float] = None,
241
+ qubit_coords: Optional[Dict[int, Tuple[int, int]]] = None,
242
+ ) -> Any:
243
+ """
244
+ Builds an interactive Plotly heatmap for zoom/pan exploration.
245
+ """
246
+ if not _HAS_PLOTLY:
247
+ raise RuntimeError("Plotly is not available in this environment.")
248
+
249
+ cfg = cfg or HeatmapConfig()
250
+ metric_key, grid, values, meta, coords = _build_metric_grid(
251
+ csv_text,
252
+ int(n_qubits),
253
+ str(metric),
254
+ cfg,
255
+ calibration_json,
256
+ state_vector,
257
+ weights,
258
+ thresholds,
259
+ qubit_coords,
260
+ )
261
+
262
+ colorscale = _METRIC_CMAP[metric_key]
263
+ zmin = float(np.min(grid))
264
+ zmax = float(np.max(grid))
265
+ if zmax <= zmin:
266
+ zmax = zmin + 1e-9
267
+
268
+ fig = go.Figure(
269
+ data=go.Heatmap(
270
+ z=grid,
271
+ colorscale=colorscale,
272
+ zmin=zmin,
273
+ zmax=zmax,
274
+ colorbar={"title": metric_key},
275
+ hoverongaps=False,
276
+ )
277
+ )
278
+
279
+ for q, (rr, cc) in coords.items():
280
+ if 0 <= rr < cfg.rows and 0 <= cc < cfg.cols:
281
+ fig.add_annotation(
282
+ x=cc,
283
+ y=rr,
284
+ text=f"q{q}",
285
+ showarrow=False,
286
+ font={"size": 10, "color": "#0f172a"},
287
+ )
288
+
289
+ notes = []
290
+ skipped = int(meta.get("skipped_rows", 0))
291
+ if skipped:
292
+ logger.warning("Skipped %d malformed CSV rows while building heatmap.", skipped)
293
+ notes.append(f"Skipped malformed CSV rows: {skipped}")
294
+
295
+ calibration_note = str(meta.get("calibration_note", "") or "").strip()
296
+ if calibration_note:
297
+ notes.append(calibration_note)
298
+
299
+ if highlight_threshold is not None:
300
+ thr = float(np.clip(float(highlight_threshold), 0.0, 1e9))
301
+ highlighted = 0
302
+ for q, (rr, cc) in coords.items():
303
+ if 0 <= q < n_qubits and 0 <= rr < cfg.rows and 0 <= cc < cfg.cols:
304
+ if float(values[q]) >= thr:
305
+ highlighted += 1
306
+ fig.add_shape(
307
+ type="rect",
308
+ x0=cc - 0.5,
309
+ x1=cc + 0.5,
310
+ y0=rr - 0.5,
311
+ y1=rr + 0.5,
312
+ line={"color": "#f59e0b", "width": 2},
313
+ fillcolor="rgba(0,0,0,0)",
314
+ )
315
+ notes.append(f"Highlighted qubits (value >= {thr:.4g}): {highlighted}")
316
+
317
+ fig.update_layout(
318
+ title=_METRIC_TITLES[metric_key],
319
+ xaxis={"title": "Chip column", "tickmode": "array", "tickvals": list(range(cfg.cols))},
320
+ yaxis={
321
+ "title": "Chip row",
322
+ "tickmode": "array",
323
+ "tickvals": list(range(cfg.rows)),
324
+ "autorange": "reversed",
325
+ "scaleanchor": "x",
326
+ "scaleratio": 1,
327
+ },
328
+ margin={"l": 50, "r": 30, "t": 70, "b": 50},
329
+ dragmode="pan",
330
+ )
331
+
332
+ if notes:
333
+ fig.add_annotation(
334
+ text=" | ".join(notes),
335
+ xref="paper",
336
+ yref="paper",
337
+ x=0.5,
338
+ y=-0.17,
339
+ showarrow=False,
340
+ font={"size": 11, "color": "#92400e"},
341
+ )
342
+
343
+ return fig
344
+
345
+
346
  def make_activity_heatmap(
347
  csv_text: str,
348
  n_qubits: int,