| """ |
| Tests for the state_store module. |
| |
| Covers: |
| * round-trip snapshot/checkout |
| * content-addressed dedup (same tensors -> same blob) |
| * async write-behind completion (queue drains) |
| * branch / log lineage walk |
| * gc removes only unreachable snapshots + blobs |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import os |
| from pathlib import Path |
|
|
| import pytest |
|
|
| torch = pytest.importorskip("torch") |
|
|
| from state_store import ( |
| StateStore, |
| snapshot, |
| checkout, |
| log, |
| diff, |
| branch, |
| gc, |
| ) |
| from state_store.store import hash_bytes |
|
|
|
|
| |
| |
| |
class TinyModel(torch.nn.Module):
    """Two-layer perceptron (4 -> 8 -> 4) used as a small subject for snapshots."""

    def __init__(self):
        super().__init__()
        # fc1 is built before fc2 so seeded initialisation stays reproducible;
        # the attribute names are what the tests reference (e.g. ``fc1.weight``).
        self.fc1 = torch.nn.Linear(in_features=4, out_features=8, bias=True)
        self.fc2 = torch.nn.Linear(in_features=8, out_features=4, bias=True)

    def forward(self, x):
        """Apply ``fc1``, a ReLU, then ``fc2`` to ``x`` and return the result."""
        hidden = self.fc1(x)
        activated = torch.relu(hidden)
        return self.fc2(activated)
|
|
|
|
def _make_model_and_opt(seed: int = 0):
    """Build a freshly seeded ``TinyModel`` together with an SGD optimizer.

    The global torch RNG is seeded with ``seed`` before the model is
    constructed.

    Returns:
        A ``(model, optimizer)`` tuple; the optimizer is ``torch.optim.SGD``
        with ``lr=0.1`` over all of the model's parameters.
    """
    torch.manual_seed(seed)
    net = TinyModel()
    return net, torch.optim.SGD(net.parameters(), lr=0.1)
|
|
|
|
@pytest.fixture
def store(tmp_path):
    """Yield a ``StateStore`` built with ``sync=True`` under ``tmp_path / "store"``.

    ``shutdown()`` is called on the store after the requesting test finishes.
    """
    sync_store = StateStore(root=tmp_path / "store", sync=True)
    yield sync_store
    # Teardown: runs once the test using this fixture has completed.
    sync_store.shutdown()
|
|
|
|
@pytest.fixture
def async_store(tmp_path):
    """Yield a ``StateStore`` built with ``sync=False`` under ``tmp_path / "async_store"``.

    ``shutdown()`` is called on the store after the requesting test finishes.
    """
    background_store = StateStore(root=tmp_path / "async_store", sync=False)
    yield background_store
    # Teardown: runs once the test using this fixture has completed.
    background_store.shutdown()
|
|
|
|
| |
| |
| |
def test_snapshot_roundtrip(store):
    """Snapshot one model, check it out into another, and compare parameters."""
    m1, o1 = _make_model_and_opt(seed=1)
    h = snapshot(
        m1,
        o1,
        step=100,
        metrics={"val_bpb": 1.777, "loss": 2.5, "step": 100},
        store=store,
    )
    assert isinstance(h, str)
    assert len(h) >= 32

    # A differently seeded model must not already match, otherwise the
    # restore check below would pass trivially.
    m2, o2 = _make_model_and_opt(seed=999)
    before = list(zip(m1.named_parameters(), m2.named_parameters()))
    for (n1, p1), (n2, p2) in before:
        assert not torch.equal(p1, p2), f"{n1}/{n2} should start different"

    # checkout returns the stored row; step and metrics must round-trip.
    row = checkout(h, m2, o2, store=store)
    assert row["step"] == 100
    assert row["metrics"]["val_bpb"] == 1.777

    # After checkout every parameter of m2 must equal its counterpart in m1.
    after = list(zip(m1.named_parameters(), m2.named_parameters()))
    for (n1, p1), (n2, p2) in after:
        assert torch.equal(p1.cpu(), p2.cpu()), f"param {n1} not restored"
|
|
|
|
| |
| |
| |
def test_content_addressed_dedup(store):
    """Identical inputs give one hash; unchanged tensors keep their blob hashes."""
    m, o = _make_model_and_opt(seed=42)
    metrics = {"val_bpb": 2.0, "loss": 3.0}

    # Two snapshots of identical state, step and metrics must hash the same.
    h1 = snapshot(m, o, step=1, metrics=metrics, store=store)
    h2 = snapshot(m, o, step=1, metrics=metrics, store=store)
    assert h1 == h2

    # A snapshot at a different step with untouched weights must map every
    # model tensor to the same blob hash as the first snapshot.
    h3 = snapshot(m, o, step=2, metrics=metrics, store=store)
    mf1 = json.loads(store.get_snapshot(h1)["manifest_json"])
    mf3 = json.loads(store.get_snapshot(h3)["manifest_json"])
    blobs_first = mf1["model"]
    blobs_third = mf3["model"]
    assert blobs_first.keys() == blobs_third.keys()
    for k, blob_hash in blobs_first.items():
        assert blob_hash == blobs_third[k], f"blob hash changed for {k}"

    # Every blob referenced by either manifest must be present in the store.
    unique_blob_hashes = set(blobs_first.values()).union(blobs_third.values())
    for bh in unique_blob_hashes:
        assert store.has_blob(bh), f"blob {bh} missing from store"
|
|
|
|
def test_snapshot_changes_when_weights_change(store):
    """Mutating one tensor changes the snapshot hash and shows up in ``diff``."""
    model, optimizer = _make_model_and_opt(seed=7)
    metrics = {"val_bpb": 1.0}
    before_hash = snapshot(model, optimizer, step=1, metrics=metrics, store=store)

    # Shift fc1.weight in place; no_grad keeps autograd out of the edit.
    with torch.no_grad():
        model.fc1.weight.add_(1.0)
    after_hash = snapshot(model, optimizer, step=2, metrics=metrics, store=store)
    assert before_hash != after_hash

    delta = diff(before_hash, after_hash, store=store)
    assert "fc1.weight" in delta["changed"]
    # Only fc1.weight was modified, so at least two blobs should be reported
    # as identical between the two snapshots.
    assert delta["identical_blob_count"] >= 2
|
|
|
|
| |
| |
| |
def test_async_writes_drain(async_store):
    """After ``flush`` every snapshot taken on the async store is retrievable."""
    model, optimizer = _make_model_and_opt(seed=3)
    hashes = []
    for step in range(5):
        # Nudge the weights so each step snapshots different tensor content.
        with torch.no_grad():
            model.fc1.weight.add_(0.01)
        snap_hash = snapshot(
            model,
            optimizer,
            step=step,
            metrics={"val_bpb": float(step)},
            store=async_store,
        )
        hashes.append(snap_hash)

    # Wait for pending writes before reading anything back.
    async_store.flush(timeout=15)

    for h in hashes:
        row = async_store.get_snapshot(h)
        assert row is not None, f"snapshot {h} not persisted"

    # All five snapshots should be listed by log.
    listed = log(limit=10, store=async_store)
    assert len(listed) == 5
|
|
|
|
| |
| |
| |
def test_branch_and_log(store):
    """A branch resolves to its target and ``log`` lists the lineage newest first."""
    model, optimizer = _make_model_and_opt(seed=2)

    def _bump_weights():
        # In-place edit so the next snapshot has different tensor content.
        with torch.no_grad():
            model.fc1.weight.add_(0.5)

    # Build a three-snapshot chain h1 <- h2 <- h3 via parent_hash.
    h1 = snapshot(model, optimizer, step=1, metrics={"val_bpb": 3.0}, store=store)
    _bump_weights()
    h2 = snapshot(
        model, optimizer, step=2, metrics={"val_bpb": 2.5},
        parent_hash=h1, store=store,
    )
    _bump_weights()
    h3 = snapshot(
        model, optimizer, step=3, metrics={"val_bpb": 2.0},
        parent_hash=h2, store=store,
    )

    branch("champ", h3, store=store)
    assert store.resolve_ref("champ") == h3

    # log on the branch should list the tip first, then its ancestors.
    lineage = log(limit=10, branch="champ", store=store)
    listed_hashes = [entry["hash"] for entry in lineage]
    assert listed_hashes == [h3, h2, h1]
|
|
|
|
| |
| |
| |
def test_gc_removes_only_unreachable(store):
    """gc must keep the branch lineage and the newest snapshot, with their blobs."""
    model, optimizer = _make_model_and_opt(seed=5)

    # Build a six-snapshot chain; each snapshot's parent is the previous one.
    hashes = []
    for step in range(6):
        with torch.no_grad():
            model.fc1.weight.add_(0.1)
        hashes.append(
            snapshot(
                model, optimizer, step=step, metrics={"val_bpb": 5.0 - step},
                parent_hash=hashes[-1] if hashes else None, store=store,
            )
        )

    # Pin the third snapshot with a branch, then collect garbage.
    branch("keep_me", hashes[2], store=store)
    res = gc(keep_last=1, reachable_from="keep_me", store=store)

    # The three snapshots on the branch lineage must be counted as kept.
    assert res["kept_snapshots"] >= 3
    # gc must leave the branch ref pointing at its target.
    assert store.resolve_ref("keep_me") == hashes[2]
    # keep_last=1: the most recent snapshot must not be reported as removed.
    assert hashes[-1] not in set(res["removed_snapshots"])
    # Anything reported as removed must really be gone from the store.
    for h in res["removed_snapshots"]:
        assert store.get_snapshot(h) is None
    # The branch lineage must survive together with every blob it references.
    for h in hashes[:3]:
        row = store.get_snapshot(h)
        assert row is not None
        manifest = json.loads(row["manifest_json"])
        for bh in manifest["model"].values():
            assert store.has_blob(bh), f"blob {bh} gc'd but snapshot {h} still references it"
|
|
|
|
def test_gc_dry_run_does_not_delete(store):
    """A dry-run gc must leave every snapshot and its model blobs in place.

    Three chained snapshots are created with no branch pinning them, then gc
    is run with ``keep_last=0`` and ``dry_run=True``. The snapshot rows and
    the blobs listed in each manifest's ``"model"`` section must all still be
    present afterwards.
    """
    m, o = _make_model_and_opt(seed=8)
    parent = None
    hashes = []
    for step in range(3):
        # Change the weights so each snapshot has different tensor content.
        with torch.no_grad():
            m.fc1.weight.add_(0.2)
        parent = snapshot(m, o, step=step, metrics={"loss": 1.0 * step},
                          parent_hash=parent, store=store)
        hashes.append(parent)

    # The return value is intentionally not bound: this test only checks that
    # the store is untouched, not the shape of the dry-run report.
    gc(keep_last=0, dry_run=True, store=store)

    for h in hashes:
        row = store.get_snapshot(h)
        assert row is not None
        # Blobs must survive as well as the snapshot rows.
        mf = json.loads(row["manifest_json"])
        for bh in mf["model"].values():
            assert store.has_blob(bh), f"blob {bh} deleted by dry-run gc (snapshot {h})"
|
|
|
|
| |
| |
| |
def test_hash_bytes_deterministic():
    """Equal byte strings hash equal; a one-character change alters the hash."""
    payload = b"hello world"
    first = hash_bytes(payload)
    repeat = hash_bytes(payload)
    # Differs from the payload only in the case of the final character.
    altered = hash_bytes(b"hello worlD")
    assert first == repeat
    assert first != altered
|
|