""" Tests for the state_store module. Covers: * round-trip snapshot/checkout * content-addressed dedup (same tensors -> same blob) * async write-behind completion (queue drains) * branch / log lineage walk * gc removes only unreachable snapshots + blobs """ from __future__ import annotations import json import os from pathlib import Path import pytest torch = pytest.importorskip("torch") from state_store import ( StateStore, snapshot, checkout, log, diff, branch, gc, ) from state_store.store import hash_bytes # --------------------------------------------------------------------------- # Tiny model + optimizer for deterministic tests # --------------------------------------------------------------------------- class TinyModel(torch.nn.Module): def __init__(self): super().__init__() self.fc1 = torch.nn.Linear(4, 8, bias=True) self.fc2 = torch.nn.Linear(8, 4, bias=True) def forward(self, x): return self.fc2(torch.relu(self.fc1(x))) def _make_model_and_opt(seed: int = 0): torch.manual_seed(seed) model = TinyModel() opt = torch.optim.SGD(model.parameters(), lr=0.1) return model, opt @pytest.fixture def store(tmp_path): # Sync store simplifies assertions; async path is covered separately below. s = StateStore(root=tmp_path / "store", sync=True) yield s s.shutdown() @pytest.fixture def async_store(tmp_path): s = StateStore(root=tmp_path / "async_store", sync=False) yield s s.shutdown() # --------------------------------------------------------------------------- # Round-trip # --------------------------------------------------------------------------- def test_snapshot_roundtrip(store): m1, o1 = _make_model_and_opt(seed=1) metrics = {"val_bpb": 1.777, "loss": 2.5, "step": 100} h = snapshot(m1, o1, step=100, metrics=metrics, store=store) assert isinstance(h, str) and len(h) >= 32 # Fresh model with different init -> checkout must restore weights. m2, o2 = _make_model_and_opt(seed=999) for (n1, p1), (n2, p2) in zip(m1.named_parameters(), m2.named_parameters()): assert not torch.equal(p1, p2), f"{n1}/{n2} should start different" row = checkout(h, m2, o2, store=store) assert row["step"] == 100 assert row["metrics"]["val_bpb"] == 1.777 for (n1, p1), (n2, p2) in zip(m1.named_parameters(), m2.named_parameters()): assert torch.equal(p1.cpu(), p2.cpu()), f"param {n1} not restored" # --------------------------------------------------------------------------- # Dedup: snapshotting the same model twice yields identical manifest entries # --------------------------------------------------------------------------- def test_content_addressed_dedup(store): m, o = _make_model_and_opt(seed=42) metrics = {"val_bpb": 2.0, "loss": 3.0} h1 = snapshot(m, o, step=1, metrics=metrics, store=store) h2 = snapshot(m, o, step=1, metrics=metrics, store=store) # Same step + state + metrics => identical snapshot hash. assert h1 == h2 # Even if the step changes, every per-tensor blob hash must be identical # because the weights themselves haven't changed. h3 = snapshot(m, o, step=2, metrics=metrics, store=store) mf1 = json.loads(store.get_snapshot(h1)["manifest_json"]) mf3 = json.loads(store.get_snapshot(h3)["manifest_json"]) assert mf1["model"].keys() == mf3["model"].keys() for k in mf1["model"]: assert mf1["model"][k] == mf3["model"][k], f"blob hash changed for {k}" # Every referenced blob must be reachable via the store (works for both # legacy per-file layout and Phase-1 chunked/packfile layout). unique_blob_hashes = set(mf1["model"].values()) | set(mf3["model"].values()) for bh in unique_blob_hashes: assert store.has_blob(bh), f"blob {bh} missing from store" def test_snapshot_changes_when_weights_change(store): m, o = _make_model_and_opt(seed=7) metrics = {"val_bpb": 1.0} h1 = snapshot(m, o, step=1, metrics=metrics, store=store) with torch.no_grad(): m.fc1.weight.add_(1.0) # mutate h2 = snapshot(m, o, step=2, metrics=metrics, store=store) assert h1 != h2 d = diff(h1, h2, store=store) assert "fc1.weight" in d["changed"] # fc2 weight/bias unchanged -> appears in identical_blob_count bucket. assert d["identical_blob_count"] >= 2 # --------------------------------------------------------------------------- # Async write-behind # --------------------------------------------------------------------------- def test_async_writes_drain(async_store): m, o = _make_model_and_opt(seed=3) hashes = [] for step in range(5): with torch.no_grad(): m.fc1.weight.add_(0.01) hashes.append( snapshot(m, o, step=step, metrics={"val_bpb": float(step)}, store=async_store) ) async_store.flush(timeout=15) # All rows visible. for h in hashes: row = async_store.get_snapshot(h) assert row is not None, f"snapshot {h} not persisted" rows = log(limit=10, store=async_store) assert len(rows) == 5 # --------------------------------------------------------------------------- # Branch + log lineage # --------------------------------------------------------------------------- def test_branch_and_log(store): m, o = _make_model_and_opt(seed=2) h1 = snapshot(m, o, step=1, metrics={"val_bpb": 3.0}, store=store) with torch.no_grad(): m.fc1.weight.add_(0.5) h2 = snapshot(m, o, step=2, metrics={"val_bpb": 2.5}, parent_hash=h1, store=store) with torch.no_grad(): m.fc1.weight.add_(0.5) h3 = snapshot(m, o, step=3, metrics={"val_bpb": 2.0}, parent_hash=h2, store=store) branch("champ", h3, store=store) assert store.resolve_ref("champ") == h3 lin = log(limit=10, branch="champ", store=store) assert [r["hash"] for r in lin] == [h3, h2, h1] # --------------------------------------------------------------------------- # GC # --------------------------------------------------------------------------- def test_gc_removes_only_unreachable(store): m, o = _make_model_and_opt(seed=5) hashes = [] parent = None for step in range(6): with torch.no_grad(): m.fc1.weight.add_(0.1) parent = snapshot( m, o, step=step, metrics={"val_bpb": 5.0 - step}, parent_hash=parent, store=store, ) hashes.append(parent) branch("keep_me", hashes[2], store=store) res = gc(keep_last=1, reachable_from="keep_me", store=store) # With keep_last=1, last snapshot is kept; plus lineage from keep_me (h0..h2). kept = res["kept_snapshots"] assert kept >= 3 # h0, h1, h2 are reachable from keep_me # keep_me head must still resolve. assert store.resolve_ref("keep_me") == hashes[2] # h3, h4 may have been removed (they're not reachable and not in keep_last=1 window). removed = set(res["removed_snapshots"]) # The last (newest) snapshot is in the keep_last=1 window, so NOT removed. assert hashes[-1] not in removed # Everything kept must still be readable. for h in res["removed_snapshots"]: assert store.get_snapshot(h) is None # Blobs for reachable snapshots must still exist on disk. for h in hashes[:3]: row = store.get_snapshot(h) assert row is not None mf = json.loads(row["manifest_json"]) for bh in mf["model"].values(): assert store.has_blob(bh), f"blob {bh} gc'd but snapshot {h} still references it" def test_gc_dry_run_does_not_delete(store): m, o = _make_model_and_opt(seed=8) parent = None hashes = [] for step in range(3): with torch.no_grad(): m.fc1.weight.add_(0.2) parent = snapshot(m, o, step=step, metrics={"loss": 1.0 * step}, parent_hash=parent, store=store) hashes.append(parent) res = gc(keep_last=0, dry_run=True, store=store) # Dry-run: snapshots still present in DB. for h in hashes: assert store.get_snapshot(h) is not None # --------------------------------------------------------------------------- # Hash utility sanity # --------------------------------------------------------------------------- def test_hash_bytes_deterministic(): a = hash_bytes(b"hello world") b = hash_bytes(b"hello world") c = hash_bytes(b"hello worlD") assert a == b assert a != c