# QureadAI / tests/test_qec_bundle.py
# Committed by hchevva in b3588b5 ("Upload 2 files", verified).
import json
import os
import pathlib
import tempfile
import unittest
import zipfile
from quread.qec_bundle import build_sample_qec_bundle_bytes, parse_qec_bundle, sample_qec_bundle_filename
def _write_bundle(manifest: dict, artifacts=None) -> str:
    """Write a QEC bundle zip to a fresh temporary file and return its path.

    The archive always contains ``manifest.json``, which is ``manifest``
    serialized with ``json.dumps``. It also contains one member per entry of
    ``artifacts``, a mapping of archive member name to content (``str`` or
    ``bytes``, as accepted by ``ZipFile.writestr``). ``artifacts`` may be
    ``None`` or empty, in which case only the manifest is written.

    The caller owns the returned file and must delete it. If building the
    archive fails (for example because ``manifest`` is not JSON-serializable),
    the temporary file is removed before the exception propagates, so failed
    test setups do not litter the temp directory.
    """
    artifacts = dict(artifacts or {})
    fd, path = tempfile.mkstemp(prefix="qec_bundle_", suffix=".zip")
    # Close the low-level descriptor right away; ZipFile reopens the path itself.
    os.close(fd)
    try:
        with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
            zf.writestr("manifest.json", json.dumps(manifest))
            for name, content in artifacts.items():
                zf.writestr(name, content)
    except BaseException:
        # Do not leak the half-written temp file; the error is re-raised unchanged.
        os.remove(path)
        raise
    return path
def _valid_manifest() -> dict:
    """Build the baseline manifest the parser tests treat as valid.

    Both decoders report the same ``num_samples``. The two listed artifacts
    (``artifacts/run.log`` and ``artifacts/model.onnx``) are meant to be
    written into the archive by the test that uses this manifest unchanged.
    A brand-new dict, with fresh nested dicts and lists, is returned on each
    call. Individual tests can therefore delete or overwrite fields to derive
    invalid variants without affecting one another.
    """
    sample_count = 10000
    density_before = 0.031

    baseline = dict(
        name="pymatching",
        ler=0.0123,
        latency_ms=4.8,
        syndrome_density_before=density_before,
        syndrome_density_after=0.031,
        logical_failures=123,
        num_samples=sample_count,
    )
    predecoded = dict(
        name="ising_predecoder_plus_pymatching",
        ler=0.0104,
        latency_ms=3.2,
        syndrome_density_before=density_before,
        syndrome_density_after=0.011,
        logical_failures=104,
        num_samples=sample_count,
    )
    listed_artifacts = [
        {"path": member, "kind": kind}
        for member, kind in (
            ("artifacts/run.log", "log"),
            ("artifacts/model.onnx", "onnx"),
        )
    ]

    # Top-level scalar fields first; insertion order matches the serialized layout.
    manifest = dict(
        bundle_version="1.0",
        source="nvidia_ising_decoding",
        code_family="surface_code",
        experiment_name="surface_d13_demo",
        distance=13,
        n_rounds=104,
        basis="X",
        rotation="O1",
        noise_model_label="public_default",
        generated_by="unit-test",
        timestamp="2026-04-15T10:00:00Z",
        notes="demo bundle",
    )
    manifest["model"] = dict(
        variant="fast",
        model_id=1,
        checkpoint_name="Ising-Decoder-SurfaceCode-1-Fast.pt",
    )
    manifest["decoders"] = {
        "baseline": baseline,
        "ai_predecoder_plus_baseline": predecoded,
    }
    manifest["artifacts"] = listed_artifacts
    return manifest
class QECBundleParserTest(unittest.TestCase):
    """Tests for ``parse_qec_bundle``: happy paths, hard failures and warnings."""

    def _bundle_path(self, manifest, artifacts=None):
        """Write a bundle with ``_write_bundle`` and schedule its deletion after the test."""
        path = _write_bundle(manifest, artifacts)
        self.addCleanup(os.remove, path)
        return path

    def test_sample_bundle_bytes_parse_successfully(self):
        """The built-in sample bundle parses cleanly from an in-memory payload."""
        parsed = parse_qec_bundle(
            {"data": build_sample_qec_bundle_bytes(), "name": sample_qec_bundle_filename()}
        )
        self.assertEqual(parsed["source_name"], sample_qec_bundle_filename())
        self.assertEqual(parsed["manifest"]["experiment_name"], "surface_d13_public_demo")
        self.assertEqual(len(parsed["artifact_rows"]), 5)
        self.assertIn("artifacts/run.log", parsed["preview_text"])
        self.assertEqual(parsed["warnings"], [])

    def test_valid_bundle_parses_and_normalizes(self):
        """A complete manifest with all listed artifacts parses without warnings."""
        path = self._bundle_path(
            _valid_manifest(),
            {
                "artifacts/run.log": "decoder run log",
                "artifacts/model.onnx": "fake-onnx",
            },
        )
        parsed = parse_qec_bundle(path)
        self.assertEqual(parsed["manifest"]["code_family"], "surface_code")
        self.assertEqual(parsed["manifest"]["source"], "nvidia_ising_decoding")
        self.assertEqual(len(parsed["artifact_rows"]), 2)
        self.assertIn("manifest.json", parsed["preview_text"])
        self.assertEqual(parsed["warnings"], [])

    def test_missing_manifest_fails(self):
        """An archive without ``manifest.json`` is rejected with a ValueError."""
        fd, path = tempfile.mkstemp(prefix="qec_bundle_missing_", suffix=".zip")
        os.close(fd)
        # Register cleanup before writing so the file is removed even if writing fails.
        self.addCleanup(os.remove, path)
        with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
            zf.writestr("artifacts/run.log", "hello")
        # Escape the dot: an unescaped "." would match any character.
        with self.assertRaisesRegex(ValueError, r"manifest\.json"):
            parse_qec_bundle(path)

    def test_missing_required_manifest_fields_fail(self):
        """Dropping a required field (``distance``) raises a ValueError naming it."""
        manifest = _valid_manifest()
        del manifest["distance"]
        path = self._bundle_path(manifest)
        with self.assertRaisesRegex(ValueError, "distance"):
            parse_qec_bundle(path)

    def test_unsupported_code_family_fails(self):
        """A non-surface-code family is rejected; the error mentions ``surface_code``."""
        manifest = _valid_manifest()
        manifest["code_family"] = "color_code"
        path = self._bundle_path(manifest)
        with self.assertRaisesRegex(ValueError, "surface_code"):
            parse_qec_bundle(path)

    def test_missing_optional_artifact_becomes_warning(self):
        """A listed but absent artifact yields one warning and a ``missing`` row."""
        path = self._bundle_path(
            _valid_manifest(),
            {
                "artifacts/run.log": "decoder run log",
            },
        )
        parsed = parse_qec_bundle(path)
        self.assertEqual(len(parsed["warnings"]), 1)
        self.assertIn("Listed artifact missing from bundle", parsed["warnings"][0])
        statuses = {row["path"]: row["status"] for row in parsed["artifact_rows"]}
        self.assertEqual(statuses["artifacts/model.onnx"], "missing")

    def test_mismatched_num_samples_adds_warning(self):
        """Decoders reporting different ``num_samples`` produce a warning."""
        manifest = _valid_manifest()
        manifest["decoders"]["ai_predecoder_plus_baseline"]["num_samples"] = 8000
        # No artifacts are written, so other warnings may appear too; only
        # the presence of the sample-count warning is asserted.
        path = self._bundle_path(manifest)
        parsed = parse_qec_bundle(path)
        self.assertTrue(any("sample counts differ" in warning for warning in parsed["warnings"]))