# NOTE: file-viewer residue (size banner, commit-hash gutter, line-number gutter) removed; the module source follows.
"""Snapshot persistence -- save, list, and select validated snapshots.

Validated snapshots are stored as frozen JSON under ``snapshots/<id>/spec.json``.
The store supports selection strategies for ``reset()`` to draw from a pool of
pre-validated snapshots rather than generating on-demand.
"""

from __future__ import annotations

import json
import logging
import random
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from open_range.protocols import SnapshotSpec

# Module-level logger named after this module; SnapshotStore uses it for its
# info and warning messages.
logger = logging.getLogger(__name__)


@dataclass(frozen=True, slots=True)
class StoredSnapshot:
    """A frozen snapshot plus its persisted identifier."""

    snapshot_id: str
    snapshot: SnapshotSpec


class SnapshotStore:
    """Persist and retrieve validated snapshot specs.

    On-disk layout::

        <store_dir>/<snapshot_id>/spec.json      # serialized SnapshotSpec
        <store_dir>/<snapshot_id>/metadata.json  # summary derived from the spec

    ``spec.json`` is the source of truth. ``metadata.json`` is a sidecar that
    :meth:`list_snapshots` regenerates whenever it is missing, unreadable, or
    differs from what the spec implies.

    NOTE: the public methods are declared ``async`` but perform ordinary
    blocking file I/O.
    """

    # Characters that would make an ID span more than one path component.
    _FORBIDDEN_ID_CHARS = ("/", "\\", "\0")

    def __init__(self, store_dir: str = "snapshots") -> None:
        """Bind the store to ``store_dir``, creating the directory if needed."""
        self.store_dir = Path(store_dir)
        self.store_dir.mkdir(parents=True, exist_ok=True)

    async def store(self, snapshot: SnapshotSpec, snapshot_id: str | None = None) -> str:
        """Save a validated snapshot to disk.

        Args:
            snapshot: The validated snapshot spec.
            snapshot_id: Optional explicit ID. When absent, an ID of the form
                ``snap_<up to three vuln types>_<unix seconds>`` is generated,
                with a numeric suffix appended if that directory already
                exists. Storing under an explicit ID that already exists
                overwrites that snapshot.

        Returns:
            The snapshot ID string.

        Raises:
            ValueError: If an explicit ``snapshot_id`` is empty, is ``.`` or
                ``..``, or contains a path separator. Such an ID would land
                outside ``<store_dir>/<id>/`` and be invisible to listing.
        """
        if snapshot_id is None:
            snapshot_id = self._generate_snapshot_id(snapshot)
        elif not self._is_valid_snapshot_id(snapshot_id):
            raise ValueError(f"Invalid snapshot ID: {snapshot_id!r}")

        snap_dir = self.store_dir / snapshot_id
        snap_dir.mkdir(parents=True, exist_ok=True)

        # Atomic writes: a crash mid-write must not leave a truncated
        # spec.json behind, because one unreadable spec fails every listing.
        self._write_text_atomic(
            snap_dir / "spec.json",
            snapshot.model_dump_json(indent=2),
        )

        # Write metadata sidecar for fast listing
        meta = self._metadata_from_snapshot(snapshot_id, snapshot)
        self._write_text_atomic(
            snap_dir / "metadata.json",
            json.dumps(meta, indent=2),
        )

        logger.info("Stored snapshot %s at %s", snapshot_id, snap_dir)
        return snapshot_id

    async def select(self, strategy: str = "latest") -> SnapshotSpec:
        """Select a snapshot from the store.

        Args:
            strategy: Selection strategy.
                - ``"latest"``: most recently stored snapshot
                - ``"random"``: uniformly random

                Any other value is treated as ``"latest"`` (with a warning).

        Returns:
            The selected SnapshotSpec.

        Raises:
            FileNotFoundError: If the store is empty.
            ValueError: If the selected spec cannot be read or validated.
        """
        return (await self.select_entry(strategy=strategy)).snapshot

    async def select_entry(self, strategy: str = "latest") -> StoredSnapshot:
        """Select a snapshot plus its persisted ID.

        Accepts the same strategies and raises the same errors as
        :meth:`select`.
        """
        spec_files = sorted(self.store_dir.glob("*/spec.json"))
        if not spec_files:
            raise FileNotFoundError(
                f"No snapshots in store: {self.store_dir}"
            )

        if strategy == "random":
            chosen = random.choice(spec_files)
        else:
            if strategy != "latest":
                # Kept as a fallback for backward compatibility, but no
                # longer silent so that typos in the strategy are noticed.
                logger.warning(
                    "Unknown selection strategy %r; falling back to 'latest'",
                    strategy,
                )
            # "latest" is decided by the modification time of spec.json.
            chosen = max(spec_files, key=lambda p: p.stat().st_mtime)

        return StoredSnapshot(
            snapshot_id=chosen.parent.name,
            snapshot=self._load_spec(chosen),
        )

    async def list_entries(self) -> list[StoredSnapshot]:
        """Return every stored snapshot plus its persisted ID.

        Entries are ordered by sorted ``spec.json`` path.

        Raises:
            ValueError: If any ``spec.json`` cannot be read or validated.
        """
        return [
            StoredSnapshot(
                snapshot_id=spec_path.parent.name,
                snapshot=self._load_spec(spec_path),
            )
            for spec_path in sorted(self.store_dir.glob("*/spec.json"))
        ]

    async def count_entries(self) -> int:
        """Return canonical snapshot count based on persisted specs.

        Every spec is loaded and validated, so an invalid spec raises
        ``ValueError`` rather than being counted.
        """
        return len(await self.list_entries())

    async def list_snapshots(self) -> list[dict[str, Any]]:
        """List all snapshots with their metadata.

        Metadata is recomputed from each spec. A sidecar that is missing,
        unreadable, or out of date is rewritten on a best-effort basis.
        ``stored_at`` is taken from the existing sidecar when it holds a
        number; otherwise it falls back to the ``spec.json`` modification
        time, so a repaired entry keeps its place in the ordering instead of
        appearing to have been stored just now.

        Returns:
            List of metadata dicts, sorted by stored_at descending.

        Raises:
            ValueError: If any ``spec.json`` cannot be read or validated.
        """
        entries = await self.list_entries()
        spec_ids = {entry.snapshot_id for entry in entries}
        results: list[dict[str, Any]] = []
        for entry in entries:
            snap_dir = self.store_dir / entry.snapshot_id
            meta_path = snap_dir / "metadata.json"
            existing_meta = self._read_metadata(meta_path)

            stored_at = existing_meta.get("stored_at") if existing_meta else None
            # bool is a subclass of int; True/False is never a real timestamp.
            if isinstance(stored_at, bool) or not isinstance(stored_at, (int, float)):
                stored_at = self._mtime_or_none(snap_dir / "spec.json")

            canonical = self._metadata_from_snapshot(
                entry.snapshot_id,
                entry.snapshot,
                stored_at=stored_at,
            )
            results.append(canonical)

            if existing_meta != canonical:
                try:
                    self._write_text_atomic(meta_path, json.dumps(canonical, indent=2))
                except OSError as exc:
                    logger.warning("Failed to repair metadata sidecar %s (%s)", meta_path, exc)

        for meta_path in self.store_dir.glob("*/metadata.json"):
            if meta_path.parent.name not in spec_ids:
                logger.warning("Ignoring orphan metadata without spec.json: %s", meta_path)

        results.sort(key=lambda m: m.get("stored_at", 0), reverse=True)
        return results

    async def get(self, snapshot_id: str) -> SnapshotSpec:
        """Load a specific snapshot by ID.

        Raises:
            FileNotFoundError: If the snapshot does not exist. An ID that is
                not a single path component (for example ``"../x"``) is
                reported the same way, so lookups never leave the store.
            ValueError: If the spec exists but cannot be read or validated.
        """
        if not self._is_valid_snapshot_id(snapshot_id):
            raise FileNotFoundError(f"Snapshot not found: {snapshot_id}")
        spec_path = self.store_dir / snapshot_id / "spec.json"
        if not spec_path.exists():
            raise FileNotFoundError(f"Snapshot not found: {snapshot_id}")
        return self._load_spec(spec_path)

    async def get_entry(self, snapshot_id: str) -> StoredSnapshot:
        """Load a specific snapshot plus its ID.

        Raises the same errors as :meth:`get`.
        """
        return StoredSnapshot(
            snapshot_id=snapshot_id,
            snapshot=await self.get(snapshot_id),
        )

    @classmethod
    def _is_valid_snapshot_id(cls, snapshot_id: object) -> bool:
        """Return True if ``snapshot_id`` names exactly one directory level."""
        return (
            isinstance(snapshot_id, str)
            and snapshot_id not in {"", ".", ".."}
            and not any(ch in snapshot_id for ch in cls._FORBIDDEN_ID_CHARS)
        )

    def _generate_snapshot_id(self, snapshot: SnapshotSpec) -> str:
        """Build an unused ID from the first three vuln types and the time."""
        vuln_types = [v.type for v in snapshot.truth_graph.vulns]
        label = "_".join(vuln_types[:3])
        # A separator inside a vuln type would turn the ID into a nested path.
        for ch in self._FORBIDDEN_ID_CHARS:
            label = label.replace(ch, "_")
        base = f"snap_{label}_{int(time.time())}"

        # The timestamp has one-second resolution; add a counter rather than
        # overwrite a snapshot stored earlier in the same second.
        candidate = base
        attempt = 0
        while (self.store_dir / candidate).exists():
            attempt += 1
            candidate = f"{base}_{attempt}"
        return candidate

    @staticmethod
    def _write_text_atomic(path: Path, text: str) -> None:
        """Write ``text`` (UTF-8) to ``path`` via a temp file and rename."""
        # The ".tmp" suffix keeps the temp file out of the "*/spec.json" and
        # "*/metadata.json" globs used elsewhere in this class.
        tmp_path = path.with_name(f"{path.name}.tmp")
        try:
            tmp_path.write_text(text, encoding="utf-8")
            tmp_path.replace(path)
        except OSError:
            try:
                tmp_path.unlink()
            except OSError:
                pass  # nothing to clean up, or cleanup itself failed
            raise

    @staticmethod
    def _read_metadata(meta_path: Path) -> dict[str, Any] | None:
        """Return the sidecar's JSON object, or None if absent or unusable."""
        try:
            if not meta_path.exists():
                return None
            loaded = json.loads(meta_path.read_text(encoding="utf-8"))
        except (ValueError, OSError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.warning("Repairing corrupt metadata: %s (%s)", meta_path, exc)
            return None
        if isinstance(loaded, dict):
            return loaded
        logger.warning(
            "Repairing metadata sidecar with non-object payload: %s",
            meta_path,
        )
        return None

    @staticmethod
    def _mtime_or_none(path: Path) -> float | None:
        """Return the modification time of ``path``, or None if unavailable."""
        try:
            return path.stat().st_mtime
        except OSError:
            return None

    @staticmethod
    def _metadata_from_snapshot(
        snapshot_id: str,
        snapshot: SnapshotSpec,
        *,
        stored_at: float | None = None,
    ) -> dict[str, Any]:
        """Derive the metadata sidecar payload for ``snapshot``.

        Args:
            snapshot_id: ID recorded in the payload.
            snapshot: Spec the summary fields are read from.
            stored_at: Timestamp to record; the current time when ``None``.

        Returns:
            A JSON-serializable dict of summary fields.
        """
        return {
            "snapshot_id": snapshot_id,
            "vuln_classes": [v.type for v in snapshot.truth_graph.vulns],
            "golden_path_steps": len(snapshot.golden_path),
            "flag_count": len(snapshot.flags),
            "npc_count": len(snapshot.npc_personas),
            "has_compose": bool(snapshot.compose),
            "has_payload_files": bool(snapshot.files),
            "live_validated": bool(snapshot.topology.get("live_validated", False)),
            "parent_snapshot_id": snapshot.lineage.parent_snapshot_id,
            "root_snapshot_id": snapshot.lineage.root_snapshot_id,
            "generation_depth": snapshot.lineage.generation_depth,
            "mutation_summary": list(snapshot.lineage.mutation_summary),
            "stored_at": float(time.time() if stored_at is None else stored_at),
        }

    @staticmethod
    def _load_spec(spec_path: Path) -> SnapshotSpec:
        """Read and validate the spec at ``spec_path``.

        Raises:
            ValueError: On any failure (I/O, JSON decoding, or validation);
                the original exception is chained as the cause.
        """
        try:
            raw = json.loads(spec_path.read_text(encoding="utf-8"))
            return SnapshotSpec.model_validate(raw)
        except Exception as exc:  # noqa: BLE001
            raise ValueError(f"invalid snapshot spec at {spec_path}: {exc}") from exc