Spaces:
Runtime error
Runtime error
File size: 8,025 Bytes
8c486a8 906af9d 8c486a8 906af9d 8c486a8 d2dbaf9 8c486a8 906af9d 8c486a8 906af9d d2dbaf9 906af9d 8c486a8 80ef9e0 d2dbaf9 80ef9e0 d2dbaf9 8c486a8 d2dbaf9 8c486a8 d2dbaf9 8c486a8 d2dbaf9 8c486a8 d2dbaf9 8c486a8 d2dbaf9 906af9d d2dbaf9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 | """Snapshot persistence -- save, list, and select validated snapshots.
Validated snapshots are stored as frozen JSON under ``snapshots/<id>/spec.json``.
The store supports selection strategies for ``reset()`` to draw from a pool of
pre-validated snapshots rather than generating on-demand.
"""
from __future__ import annotations
import json
import logging
import random
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from open_range.protocols import SnapshotSpec
# Module-level logger; SnapshotStore reports writes and metadata repairs here.
logger = logging.getLogger(__name__)


@dataclass(frozen=True, slots=True)
class StoredSnapshot:
    """A frozen snapshot plus its persisted identifier.

    Immutable pair returned by ``SnapshotStore`` lookups so callers learn
    which on-disk entry a spec came from.
    """

    # Name of the snapshot's directory directly under the store root.
    snapshot_id: str
    # The validated spec loaded from that directory's ``spec.json``.
    snapshot: SnapshotSpec
class SnapshotStore:
    """Persist and retrieve validated snapshot specs.

    Each snapshot lives in its own directory under ``store_dir``::

        <store_dir>/<snapshot_id>/spec.json      # validated spec (source of truth)
        <store_dir>/<snapshot_id>/metadata.json  # derived summary for fast listing

    Only directories containing ``spec.json`` count as snapshots. The
    metadata sidecar is recomputed from the spec by :meth:`list_snapshots`
    and rewritten whenever it is missing, unreadable, or stale.
    """

    def __init__(self, store_dir: str = "snapshots") -> None:
        """Open the store rooted at *store_dir*, creating the directory if needed."""
        self.store_dir = Path(store_dir)
        # Create eagerly so glob()-based readers never hit a missing root.
        self.store_dir.mkdir(parents=True, exist_ok=True)

    async def store(self, snapshot: SnapshotSpec, snapshot_id: str | None = None) -> str:
        """Save a validated snapshot to disk.

        Args:
            snapshot: The validated snapshot spec.
            snapshot_id: Optional explicit ID. An existing snapshot with the
                same ID is overwritten. When absent, an ID of the form
                ``snap_<up to 3 vuln types>_<unix seconds>`` is generated;
                if that ID is already taken, ``_1``, ``_2``, ... is appended
                so two stores within the same second never clobber each other.

        Returns:
            The snapshot ID string.
        """
        if snapshot_id is None:
            snapshot_id, snap_dir = self._claim_generated_dir(snapshot)
        else:
            snap_dir = self.store_dir / snapshot_id
            snap_dir.mkdir(parents=True, exist_ok=True)
        # spec.json first: readers discover snapshots by globbing for it, and
        # the temp-file + rename write keeps them from reading a partial file.
        self._atomic_write_text(snap_dir / "spec.json", snapshot.model_dump_json(indent=2))
        # Write metadata sidecar for fast listing
        meta = self._metadata_from_snapshot(snapshot_id, snapshot)
        self._atomic_write_text(snap_dir / "metadata.json", json.dumps(meta, indent=2))
        logger.info("Stored snapshot %s at %s", snapshot_id, snap_dir)
        return snapshot_id

    async def select(self, strategy: str = "latest") -> SnapshotSpec:
        """Select a snapshot from the store.

        Args:
            strategy: Selection strategy.
                - ``"latest"``: most recently stored snapshot
                - ``"random"``: uniformly random

        Returns:
            The selected SnapshotSpec.

        Raises:
            FileNotFoundError: If the store is empty.
            ValueError: If the chosen spec file is invalid.
        """
        return (await self.select_entry(strategy=strategy)).snapshot

    async def select_entry(self, strategy: str = "latest") -> StoredSnapshot:
        """Select a snapshot plus its persisted ID.

        Args:
            strategy: ``"random"`` draws uniformly from all stored specs.
                ``"latest"`` picks the snapshot whose ``spec.json`` was
                modified most recently. Any other value is logged and
                treated as ``"latest"`` (kept for backward compatibility).

        Raises:
            FileNotFoundError: If the store is empty.
            ValueError: If the chosen spec file is invalid.
        """
        spec_files = sorted(self.store_dir.glob("*/spec.json"))
        if not spec_files:
            raise FileNotFoundError(
                f"No snapshots in store: {self.store_dir}"
            )
        if strategy == "random":
            chosen = random.choice(spec_files)
        else:
            if strategy != "latest":
                logger.warning(
                    "Unknown snapshot selection strategy %r; using 'latest'",
                    strategy,
                )
            # latest -- by mtime of spec.json itself (rewritten on every store()).
            chosen = max(spec_files, key=lambda p: p.stat().st_mtime)
        return StoredSnapshot(
            snapshot_id=chosen.parent.name,
            snapshot=self._load_spec(chosen),
        )

    async def list_entries(self) -> list[StoredSnapshot]:
        """Return every stored snapshot plus its persisted ID, ordered by path.

        Raises:
            ValueError: If any ``spec.json`` is not a valid snapshot spec.
        """
        return [
            StoredSnapshot(
                snapshot_id=spec_path.parent.name,
                snapshot=self._load_spec(spec_path),
            )
            for spec_path in sorted(self.store_dir.glob("*/spec.json"))
        ]

    async def count_entries(self) -> int:
        """Return canonical snapshot count based on persisted specs.

        Every spec is loaded and validated, so an invalid spec raises
        ``ValueError`` instead of being counted.
        """
        return len(await self.list_entries())

    async def list_snapshots(self) -> list[dict[str, Any]]:
        """List all snapshots with their metadata.

        Metadata is recomputed from each validated spec. The only value
        carried over from an existing sidecar is a numeric ``stored_at``.
        Otherwise the ``spec.json`` modification time is used, or "now" if
        that cannot be read. Sidecars that differ from the recomputed form are
        rewritten on a best-effort basis: write failures are logged, not raised.

        Returns:
            List of metadata dicts, sorted by stored_at descending.

        Raises:
            ValueError: If any ``spec.json`` is not a valid snapshot spec.
        """
        entries = await self.list_entries()
        spec_ids = {entry.snapshot_id for entry in entries}
        results: list[dict[str, Any]] = []
        for entry in entries:
            snap_dir = self.store_dir / entry.snapshot_id
            meta_path = snap_dir / "metadata.json"
            existing_meta = self._read_metadata(meta_path)
            stored_at = existing_meta.get("stored_at") if existing_meta else None
            # bool subclasses int, but a JSON true/false is not a timestamp.
            if isinstance(stored_at, bool) or not isinstance(stored_at, (int, float)):
                # Keep ordering consistent with select_entry("latest") instead
                # of stamping repaired entries with the current time.
                stored_at = self._mtime_or_none(snap_dir / "spec.json")
            canonical = self._metadata_from_snapshot(
                entry.snapshot_id,
                entry.snapshot,
                stored_at=stored_at,
            )
            results.append(canonical)
            if existing_meta != canonical:
                try:
                    self._atomic_write_text(meta_path, json.dumps(canonical, indent=2))
                except OSError as exc:
                    logger.warning("Failed to repair metadata sidecar %s (%s)", meta_path, exc)
        for meta_path in self.store_dir.glob("*/metadata.json"):
            if meta_path.parent.name not in spec_ids:
                logger.warning("Ignoring orphan metadata without spec.json: %s", meta_path)
        results.sort(key=lambda m: m.get("stored_at", 0), reverse=True)
        return results

    async def get(self, snapshot_id: str) -> SnapshotSpec:
        """Load a specific snapshot by ID.

        Raises:
            FileNotFoundError: If the snapshot does not exist.
            ValueError: If the stored spec is invalid.
        """
        spec_path = self.store_dir / snapshot_id / "spec.json"
        if not spec_path.exists():
            raise FileNotFoundError(f"Snapshot not found: {snapshot_id}")
        return self._load_spec(spec_path)

    async def get_entry(self, snapshot_id: str) -> StoredSnapshot:
        """Load a specific snapshot plus its ID (same errors as :meth:`get`)."""
        return StoredSnapshot(
            snapshot_id=snapshot_id,
            snapshot=await self.get(snapshot_id),
        )

    def _claim_generated_dir(self, snapshot: SnapshotSpec) -> tuple[str, Path]:
        """Pick an unused auto-generated ID and create its directory.

        ``mkdir`` without ``exist_ok`` checks for and creates the directory in
        one OS call. Two concurrent callers therefore cannot claim the same
        ID; the loser gets ``FileExistsError`` and tries the next suffix.
        """
        vuln_types = [v.type for v in snapshot.truth_graph.vulns]
        base_id = f"snap_{'_'.join(vuln_types[:3])}_{int(time.time())}"
        candidate = base_id
        attempt = 0
        while True:
            snap_dir = self.store_dir / candidate
            try:
                snap_dir.mkdir(parents=True)
            except FileExistsError:
                attempt += 1
                candidate = f"{base_id}_{attempt}"
                continue
            return candidate, snap_dir

    @staticmethod
    def _atomic_write_text(path: Path, text: str) -> None:
        """Write *text* (UTF-8) to *path* via a sibling temp file and rename.

        ``Path.replace`` is an atomic rename on POSIX when source and target
        share a filesystem. Readers therefore see either the old file or the
        new one, never a truncated one. The temp name starts with "." so it
        never matches the ``*/spec.json`` / ``*/metadata.json`` globs.
        """
        tmp_path = path.with_name(f".{path.name}.tmp")
        try:
            tmp_path.write_text(text, encoding="utf-8")
            tmp_path.replace(path)
        except BaseException:
            # Do not leave stray temp files behind on failure; re-raise as-is.
            tmp_path.unlink(missing_ok=True)
            raise

    @staticmethod
    def _read_metadata(meta_path: Path) -> dict[str, Any] | None:
        """Return the parsed sidecar at *meta_path*, or None if absent/unusable.

        Undecodable bytes (``UnicodeDecodeError``) and malformed JSON
        (``JSONDecodeError``) are both ``ValueError`` subclasses. They are
        logged and treated as "needs repair" rather than aborting the listing.
        """
        try:
            if not meta_path.exists():
                return None
            loaded = json.loads(meta_path.read_text(encoding="utf-8"))
        except (ValueError, OSError) as exc:
            logger.warning("Repairing corrupt metadata: %s (%s)", meta_path, exc)
            return None
        if isinstance(loaded, dict):
            return loaded
        logger.warning(
            "Repairing metadata sidecar with non-object payload: %s",
            meta_path,
        )
        return None

    @staticmethod
    def _mtime_or_none(path: Path) -> float | None:
        """Return *path*'s modification time, or None if it cannot be stat'ed."""
        try:
            return path.stat().st_mtime
        except OSError:
            return None

    @staticmethod
    def _metadata_from_snapshot(
        snapshot_id: str,
        snapshot: SnapshotSpec,
        *,
        stored_at: float | None = None,
    ) -> dict[str, Any]:
        """Build the JSON-serializable metadata sidecar for *snapshot*.

        ``stored_at`` defaults to the current Unix time and is always
        coerced to ``float``.
        """
        return {
            "snapshot_id": snapshot_id,
            "vuln_classes": [v.type for v in snapshot.truth_graph.vulns],
            "golden_path_steps": len(snapshot.golden_path),
            "flag_count": len(snapshot.flags),
            "npc_count": len(snapshot.npc_personas),
            "has_compose": bool(snapshot.compose),
            "has_payload_files": bool(snapshot.files),
            "live_validated": bool(snapshot.topology.get("live_validated", False)),
            "parent_snapshot_id": snapshot.lineage.parent_snapshot_id,
            "root_snapshot_id": snapshot.lineage.root_snapshot_id,
            "generation_depth": snapshot.lineage.generation_depth,
            "mutation_summary": list(snapshot.lineage.mutation_summary),
            "stored_at": float(time.time() if stored_at is None else stored_at),
        }

    @staticmethod
    def _load_spec(spec_path: Path) -> SnapshotSpec:
        """Read and validate *spec_path*.

        Raises:
            ValueError: On any failure (I/O, JSON decoding, or model
                validation), chained to the original exception.
        """
        try:
            raw = json.loads(spec_path.read_text(encoding="utf-8"))
            return SnapshotSpec.model_validate(raw)
        except Exception as exc:  # noqa: BLE001
            raise ValueError(f"invalid snapshot spec at {spec_path}: {exc}") from exc
|