# QureadAI — quread/trends.py
# Calibration-snapshot trend analysis: per-qubit metric series, drift alerts,
# snapshot deltas, and CSV export of alerts.
from __future__ import annotations
import json
import csv
import io
from typing import Any, Dict, List, Tuple
import numpy as np
from .metrics import (
MetricThresholds,
MetricWeights,
compute_metrics_from_csv,
)
# User-facing metric names (lowercased) mapped to their canonical metric keys.
# Any name not present here falls back to "composite_risk" in _resolve_metric.
_TREND_METRIC_ALIASES = {
    "composite": "composite_risk",
    "composite_risk": "composite_risk",
    "gate_error": "gate_error",
    "readout_error": "readout_error",
    "decoherence_risk": "decoherence_risk",
    "fidelity": "fidelity",
    "state_fidelity": "state_fidelity",
    "process_fidelity": "process_fidelity",
    "coherence_health": "coherence_health",
}
# Metric keys where a LARGER value means MORE risk. All other metrics
# (fidelities, coherence health) are "higher is better" and are inverted
# (1 - value) by _risk_transform before risk comparisons.
_HIGHER_IS_RISK = {
    "composite_risk",
    "gate_error",
    "readout_error",
    "decoherence_risk",
}
def _resolve_metric(metric: str) -> str:
    """Canonicalize a metric name via the alias table; unknown names become "composite_risk"."""
    wanted = str(metric or "composite_risk").strip().lower()
    return _TREND_METRIC_ALIASES.get(wanted, "composite_risk")
def _higher_is_risk(metric_key: str) -> bool:
    """Return True when larger values of this metric indicate greater risk."""
    key = str(metric_key)
    return key in _HIGHER_IS_RISK
def _risk_transform(series: np.ndarray, metric_key: str) -> np.ndarray:
    """Map a metric series onto a [0, 1] risk scale.

    Higher-is-risk metrics are clipped as-is; higher-is-better metrics are
    inverted (1 - value) first, so larger output always means more risk.
    """
    arr = np.asarray(series, dtype=float)
    oriented = arr if _higher_is_risk(metric_key) else 1.0 - arr
    return np.clip(oriented, 0.0, 1.0)
def _slope(values: np.ndarray) -> float:
v = np.asarray(values, dtype=float).reshape(-1)
if v.size <= 1:
return 0.0
return float((v[-1] - v[0]) / max(1, v.size - 1))
def _snapshot_timestamp(snapshot: Dict[str, Any], idx: int) -> str:
ts = snapshot.get("timestamp")
if ts is None:
ts = snapshot.get("ts")
if ts is None:
ts = snapshot.get("date")
text = str(ts).strip() if ts is not None else ""
if text:
return text
return f"snapshot_{idx + 1}"
def _normalize_snapshot(snapshot: Dict[str, Any], idx: int) -> Dict[str, Any] | None:
    """Reduce one raw snapshot dict to {'timestamp', 'calibration_json'}.

    Accepts a 'calibration' sub-dict, a 'qubits' sub-dict, or a flat mapping
    of numeric qubit ids to dicts. Returns None when the input is not a dict
    or contains nothing calibration-like.
    """
    if not isinstance(snapshot, dict):
        return None
    calibration = snapshot.get("calibration")
    qubits = snapshot.get("qubits")
    if isinstance(calibration, dict):
        payload = calibration
    elif isinstance(qubits, dict):
        payload = {"qubits": qubits}
    else:
        # Fall back to treating digit-keyed dict entries as per-qubit records.
        digit_entries = {
            str(key): value
            for key, value in snapshot.items()
            if str(key).strip().isdigit() and isinstance(value, dict)
        }
        if not digit_entries:
            return None
        payload = {"qubits": digit_entries}
    return {
        "timestamp": _snapshot_timestamp(snapshot, idx),
        "calibration_json": json.dumps(payload),
    }
def parse_calibration_snapshots_text(text: str) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Parse calibration snapshots from free-form text.

    Supported inputs: a JSON object with a 'snapshots' list, a JSON list, a
    single JSON object, or JSONL (one JSON object per line).

    Returns (normalized snapshots, metadata with 'format', 'parsed', 'skipped').
    """
    raw = str(text or "").strip()
    if not raw:
        return [], {"format": "empty", "parsed": 0, "skipped": 0}
    skipped = 0
    try:
        decoded = json.loads(raw)
    except Exception:
        # Whole-text JSON failed: try line-by-line (JSONL), skipping bad lines.
        fmt = "jsonl"
        candidates: List[Any] = []
        for line in raw.splitlines():
            piece = line.strip()
            if not piece:
                continue
            try:
                candidates.append(json.loads(piece))
            except Exception:
                skipped += 1
    else:
        if isinstance(decoded, dict) and isinstance(decoded.get("snapshots"), list):
            fmt, candidates = "json:snapshots", decoded["snapshots"]
        elif isinstance(decoded, list):
            fmt, candidates = "json:list", decoded
        elif isinstance(decoded, dict):
            fmt, candidates = "json:single", [decoded]
        else:
            # Scalar JSON (number/string/bool) carries no snapshot data.
            fmt, candidates = "unknown", []
            skipped += 1
    snapshots: List[Dict[str, Any]] = []
    for idx, candidate in enumerate(candidates):
        normalized = _normalize_snapshot(candidate, idx)
        if normalized is None:
            skipped += 1
        else:
            snapshots.append(normalized)
    return snapshots, {"format": fmt, "parsed": len(snapshots), "skipped": int(skipped)}
def compute_metric_trends(
    csv_text: str,
    n_qubits: int,
    snapshots_text: str,
    *,
    metric: str = "composite_risk",
    state_vector: np.ndarray | None = None,
    weights: MetricWeights | None = None,
    thresholds: MetricThresholds | None = None,
) -> Tuple[np.ndarray, List[str], List[Dict[str, float]], Dict[str, Any]]:
    """Recompute one metric across every calibration snapshot and rank qubits.

    Returns:
        series: float array of shape [snapshots, qubits].
        labels: one timestamp label per snapshot.
        ranking: per-qubit dicts (latest/baseline/delta plus risk-space values),
            sorted by latest risk descending.
        meta: parse metadata extended with metric/points/risk-mode fields.

    Raises:
        ValueError: when no usable snapshots are found in *snapshots_text*.
    """
    metric_key = _resolve_metric(metric)
    snapshots, meta = parse_calibration_snapshots_text(snapshots_text)
    if not snapshots:
        raise ValueError("No valid calibration snapshots found.")
    labels = [str(snap["timestamp"]) for snap in snapshots]
    rows = [
        np.asarray(
            compute_metrics_from_csv(
                csv_text,
                int(n_qubits),
                calibration_json=str(snap["calibration_json"]),
                state_vector=state_vector,
                weights=weights,
                thresholds=thresholds,
            )[0][metric_key],
            dtype=float,
        )
        for snap in snapshots
    ]
    series = np.vstack(rows)
    risk = _risk_transform(series, metric_key)
    baseline_raw, latest_raw = series[0], series[-1]
    baseline_risk, latest_risk = risk[0], risk[-1]
    ranking: List[Dict[str, float]] = []
    for qubit in range(int(n_qubits)):
        ranking.append(
            {
                "qubit": int(qubit),
                "latest": float(latest_raw[qubit]),
                "baseline": float(baseline_raw[qubit]),
                "delta": float(latest_raw[qubit] - baseline_raw[qubit]),
                "latest_risk": float(latest_risk[qubit]),
                "baseline_risk": float(baseline_risk[qubit]),
                "risk_delta": float(latest_risk[qubit] - baseline_risk[qubit]),
                "risk_slope": _slope(risk[:, qubit]),
            }
        )
    ranking.sort(key=lambda row: row["latest_risk"], reverse=True)
    risky = _higher_is_risk(metric_key)
    meta["metric"] = metric_key
    meta["points"] = int(series.shape[0])
    meta["higher_is_risk"] = bool(risky)
    meta["risk_mode"] = "higher_is_risk" if risky else "lower_is_risk"
    return series, labels, ranking, meta
def compute_drift_alerts(
    ranking: List[Dict[str, float]],
    *,
    warning_threshold: float,
    critical_threshold: float,
    delta_warning: float,
    delta_critical: float,
    slope_warning: float,
    slope_critical: float,
) -> List[Dict[str, Any]]:
    """Classify each ranked qubit as ok/warning/critical.

    Three signals are checked independently — latest risk level, risk delta,
    and risk slope — and the worst one sets the severity. Thresholds are
    sanitized so each critical threshold is at least its warning threshold and
    delta/slope thresholds are non-negative. Output is sorted worst-first.
    """
    level_warn = float(np.clip(float(warning_threshold), 0.0, 1.0))
    level_crit = max(float(np.clip(float(critical_threshold), 0.0, 1.0)), level_warn)
    delta_warn = float(max(0.0, delta_warning))
    delta_crit = float(max(delta_warn, delta_critical))
    slope_warn = float(max(0.0, slope_warning))
    slope_crit = float(max(slope_warn, slope_critical))
    alerts: List[Dict[str, Any]] = []
    for row in ranking:
        latest_risk = float(np.clip(float(row["latest_risk"]), 0.0, 1.0))
        risk_delta = float(row["risk_delta"])
        risk_slope = float(row["risk_slope"])
        triggers: List[str] = []
        severity = 0
        # Table-driven version of the three identical warn/critical checks.
        checks = (
            ("latest_risk", latest_risk, level_warn, level_crit),
            ("risk_delta", risk_delta, delta_warn, delta_crit),
            ("risk_slope", risk_slope, slope_warn, slope_crit),
        )
        for name, value, warn_at, crit_at in checks:
            if value >= crit_at:
                severity = max(severity, 2)
                triggers.append(f"{name}>=critical({value:.3f}>={crit_at:.3f})")
            elif value >= warn_at:
                severity = max(severity, 1)
                triggers.append(f"{name}>=warning({value:.3f}>={warn_at:.3f})")
        alerts.append(
            {
                "qubit": int(row["qubit"]),
                "level": ("ok", "warning", "critical")[severity],
                "severity": severity,
                "latest_risk": latest_risk,
                "risk_delta": risk_delta,
                "risk_slope": risk_slope,
                "triggers": triggers,
            }
        )
    alerts.sort(
        key=lambda alert: (
            int(alert["severity"]),
            float(alert["latest_risk"]),
            float(alert["risk_delta"]),
            float(alert["risk_slope"]),
        ),
        reverse=True,
    )
    return alerts
def compute_snapshot_delta(
    series: np.ndarray,
    metric_key: str,
    *,
    from_index: int,
    to_index: int,
) -> Tuple[np.ndarray, np.ndarray, int, int]:
    """Per-qubit raw and risk-space deltas between two snapshots of *series*.

    Indices are clamped into range; when they coincide, a neighbouring
    snapshot is substituted so the comparison always spans two snapshots.

    Returns (raw_delta, risk_delta, resolved_from_index, resolved_to_index).

    Raises:
        ValueError: when *series* is not 2D or has fewer than two snapshots.
    """
    data = np.asarray(series, dtype=float)
    if data.ndim != 2:
        raise ValueError("series must be a 2D array of shape [snapshots, qubits].")
    count = int(data.shape[0])
    if count < 2:
        raise ValueError("Need at least 2 snapshots to compute delta.")
    src = int(np.clip(int(from_index), 0, count - 1))
    dst = int(np.clip(int(to_index), 0, count - 1))
    if src == dst:
        # Degenerate request: widen to an adjacent snapshot.
        if dst == count - 1:
            src = max(0, dst - 1)
        else:
            dst = min(count - 1, src + 1)
    raw_delta = data[dst] - data[src]
    risk = _risk_transform(data, _resolve_metric(metric_key))
    risk_delta = risk[dst] - risk[src]
    return np.asarray(raw_delta, dtype=float), np.asarray(risk_delta, dtype=float), int(src), int(dst)
def alerts_to_csv(alerts: List[Dict[str, Any]]) -> str:
    """Serialize drift alerts to CSV text with a fixed header row.

    Missing fields fall back to defaults (-1 qubit, 'unknown' level, 0.0
    values); trigger lists are joined with '; ' into a single column.
    """
    buffer = io.StringIO()
    writer = csv.writer(buffer)
    writer.writerow(["qubit", "level", "latest_risk", "risk_delta", "risk_slope", "triggers"])
    for alert in alerts:
        joined_triggers = "; ".join(str(t) for t in (alert.get("triggers") or []))
        writer.writerow(
            [
                int(alert.get("qubit", -1)),
                str(alert.get("level", "unknown")),
                format(float(alert.get("latest_risk", 0.0)), ".6g"),
                format(float(alert.get("risk_delta", 0.0)), ".6g"),
                format(float(alert.get("risk_slope", 0.0)), ".6g"),
                joined_triggers,
            ]
        )
    return buffer.getvalue()