# feather-a10g-large-runtime/overlay/tests/test_state_store.py
# Last commit: c475135 (verified) by icarus112 -
# "Update Feather a10g-large training runtime image"
"""
Tests for the state_store module.
Covers:
* round-trip snapshot/checkout
* content-addressed dedup (same tensors -> same blob)
* diff reports only the tensors that actually changed
* async write-behind completion (queue drains)
* branch / log lineage walk
* gc removes only unreachable snapshots + blobs; a dry run deletes nothing
* hash_bytes determinism
"""
from __future__ import annotations
import json
import os
from pathlib import Path
import pytest
torch = pytest.importorskip("torch")
from state_store import (
StateStore,
snapshot,
checkout,
log,
diff,
branch,
gc,
)
from state_store.store import hash_bytes
# ---------------------------------------------------------------------------
# Tiny model + optimizer for deterministic tests
# ---------------------------------------------------------------------------
class TinyModel(torch.nn.Module):
    """Two-layer MLP (4 -> 8 -> 4, ReLU in between) used as a tiny test fixture.

    The attribute names ``fc1``/``fc2`` are relied on by the tests
    (e.g. the ``"fc1.weight"`` key in ``diff`` results), and ``fc1`` is built
    before ``fc2`` so seeded initialisation stays reproducible.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 8)  # bias defaults to True
        self.fc2 = torch.nn.Linear(8, 4)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)
def _make_model_and_opt(seed: int = 0):
    """Return a ``(TinyModel, SGD)`` pair whose weights are fixed by *seed*.

    The global torch RNG is seeded first so the same seed always yields the
    same initial weights; the optimizer is plain SGD with ``lr=0.1``.
    """
    torch.manual_seed(seed)
    net = TinyModel()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
    return net, optimizer
@pytest.fixture
def store(tmp_path):
    """Yield a ``StateStore`` created with ``sync=True`` under ``tmp_path/store``.

    Sync mode keeps the assertions simple; the asynchronous write path has its
    own ``async_store`` fixture. ``shutdown()`` runs as fixture teardown.
    """
    sync_store = StateStore(root=tmp_path / "store", sync=True)
    yield sync_store
    sync_store.shutdown()
@pytest.fixture
def async_store(tmp_path):
    """Yield a ``StateStore`` created with ``sync=False`` under ``tmp_path/async_store``.

    Used to exercise the write-behind path; ``shutdown()`` runs as fixture
    teardown.
    """
    write_behind_store = StateStore(root=tmp_path / "async_store", sync=False)
    yield write_behind_store
    write_behind_store.shutdown()
# ---------------------------------------------------------------------------
# Round-trip
# ---------------------------------------------------------------------------
def test_snapshot_roundtrip(store):
    """Snapshot a model, then check it out into a differently-seeded copy."""
    src_model, src_opt = _make_model_and_opt(seed=1)
    metrics = {"val_bpb": 1.777, "loss": 2.5, "step": 100}
    snap_hash = snapshot(src_model, src_opt, step=100, metrics=metrics, store=store)
    assert isinstance(snap_hash, str) and len(snap_hash) >= 32

    # Precondition: a different seed gives different weights, so a passing
    # comparison below really proves that checkout restored them.
    dst_model, dst_opt = _make_model_and_opt(seed=999)
    for (name_src, p_src), (name_dst, p_dst) in zip(
        src_model.named_parameters(), dst_model.named_parameters()
    ):
        assert not torch.equal(p_src, p_dst), f"{name_src}/{name_dst} should start different"

    restored = checkout(snap_hash, dst_model, dst_opt, store=store)
    assert restored["step"] == 100
    assert restored["metrics"]["val_bpb"] == 1.777

    # Parameters are re-enumerated after checkout rather than cached beforehand.
    for (name_src, p_src), (_name_dst, p_dst) in zip(
        src_model.named_parameters(), dst_model.named_parameters()
    ):
        assert torch.equal(p_src.cpu(), p_dst.cpu()), f"param {name_src} not restored"
# ---------------------------------------------------------------------------
# Dedup: snapshotting the same model twice yields identical manifest entries
# ---------------------------------------------------------------------------
def test_content_addressed_dedup(store):
    """Identical state hashes identically; unchanged tensors keep their blob hashes."""
    model, opt = _make_model_and_opt(seed=42)
    metrics = {"val_bpb": 2.0, "loss": 3.0}

    first = snapshot(model, opt, step=1, metrics=metrics, store=store)
    second = snapshot(model, opt, step=1, metrics=metrics, store=store)
    # Same step, state and metrics must collapse to one snapshot hash.
    assert first == second

    # Bumping only the step yields a new snapshot, but no weight moved, so
    # every per-tensor blob hash in the model manifest must be unchanged.
    bumped = snapshot(model, opt, step=2, metrics=metrics, store=store)

    def model_blobs(snap_hash):
        """Return the ``{tensor name: blob hash}`` map of a snapshot's model manifest."""
        return json.loads(store.get_snapshot(snap_hash)["manifest_json"])["model"]

    base_blobs = model_blobs(first)
    bumped_blobs = model_blobs(bumped)
    assert base_blobs.keys() == bumped_blobs.keys()
    for tensor_name, blob in base_blobs.items():
        assert blob == bumped_blobs[tensor_name], f"blob hash changed for {tensor_name}"

    # has_blob abstracts over the storage layout (legacy per-file or
    # Phase-1 chunked/packfile), so this holds for either.
    for blob in {*base_blobs.values(), *bumped_blobs.values()}:
        assert store.has_blob(blob), f"blob {blob} missing from store"
def test_snapshot_changes_when_weights_change(store):
    """Mutating one tensor changes the snapshot hash, and ``diff`` reports only it.

    Only ``fc1.weight`` is modified; ``fc1.bias``, ``fc2.weight`` and
    ``fc2.bias`` keep their contents and must not be reported as changed.
    """
    model, opt = _make_model_and_opt(seed=7)
    metrics = {"val_bpb": 1.0}
    before = snapshot(model, opt, step=1, metrics=metrics, store=store)
    with torch.no_grad():
        model.fc1.weight.add_(1.0)  # mutate exactly one tensor
    after = snapshot(model, opt, step=2, metrics=metrics, store=store)
    assert before != after

    d = diff(before, after, store=store)
    assert "fc1.weight" in d["changed"]
    # Without these negative checks the test would still pass if diff flagged
    # every tensor as changed.
    for untouched in ("fc1.bias", "fc2.weight", "fc2.bias"):
        assert untouched not in d["changed"], f"{untouched} reported changed but was not mutated"
    # The untouched tensors share blobs between the two snapshots and so are
    # counted in the identical-blob bucket.
    assert d["identical_blob_count"] >= 2
# ---------------------------------------------------------------------------
# Async write-behind
# ---------------------------------------------------------------------------
def test_async_writes_drain(async_store):
    """Snapshots queued on a ``sync=False`` store are all visible after ``flush``."""
    model, opt = _make_model_and_opt(seed=3)
    queued = []
    for step in range(5):
        # Nudge the weights so each step produces a distinct snapshot.
        with torch.no_grad():
            model.fc1.weight.add_(0.01)
        snap_hash = snapshot(
            model, opt, step=step, metrics={"val_bpb": float(step)}, store=async_store
        )
        queued.append(snap_hash)

    async_store.flush(timeout=15)

    for snap_hash in queued:
        assert async_store.get_snapshot(snap_hash) is not None, f"snapshot {snap_hash} not persisted"
    assert len(log(limit=10, store=async_store)) == 5
# ---------------------------------------------------------------------------
# Branch + log lineage
# ---------------------------------------------------------------------------
def test_branch_and_log(store):
    """A branch ref resolves to its head, and ``log`` walks the parent chain newest-first."""
    model, opt = _make_model_and_opt(seed=2)
    lineage = []
    for step, bpb in ((1, 3.0), (2, 2.5), (3, 2.0)):
        extra = {}
        if lineage:
            # Every snapshot after the root first moves the weights, then
            # links to its predecessor. The root passes no parent_hash at all.
            with torch.no_grad():
                model.fc1.weight.add_(0.5)
            extra["parent_hash"] = lineage[-1]
        lineage.append(
            snapshot(model, opt, step=step, metrics={"val_bpb": bpb}, store=store, **extra)
        )

    tip = lineage[-1]
    branch("champ", tip, store=store)
    assert store.resolve_ref("champ") == tip

    walked = [row["hash"] for row in log(limit=10, branch="champ", store=store)]
    assert walked == lineage[::-1]
# ---------------------------------------------------------------------------
# GC
# ---------------------------------------------------------------------------
def test_gc_removes_only_unreachable(store):
    """GC must keep the ``keep_last`` window and everything reachable from a ref.

    Six snapshots are chained h0 <- h1 <- ... <- h5 via ``parent_hash``, and
    ``keep_me`` points at h2. This test does not pin down whether h3/h4
    survive; it only requires that protected snapshots and their blobs do.
    """
    model, opt = _make_model_and_opt(seed=5)
    hashes = []
    parent = None
    for step in range(6):
        with torch.no_grad():
            model.fc1.weight.add_(0.1)
        parent = snapshot(
            model, opt, step=step, metrics={"val_bpb": 5.0 - step},
            parent_hash=parent, store=store,
        )
        hashes.append(parent)
    branch("keep_me", hashes[2], store=store)

    res = gc(keep_last=1, reachable_from="keep_me", store=store)

    # h0, h1, h2 are reachable from keep_me.
    assert res["kept_snapshots"] >= 3
    # The ref itself must still resolve to its head.
    assert store.resolve_ref("keep_me") == hashes[2]

    removed = set(res["removed_snapshots"])
    # Neither the keep_me lineage nor the newest snapshot (keep_last=1 window)
    # may appear in the removal list.
    assert removed.isdisjoint(hashes[:3]), f"gc removed keep_me lineage: {removed & set(hashes[:3])}"
    assert hashes[-1] not in removed

    # Everything reported as removed must really be gone from the DB.
    for h in removed:
        assert store.get_snapshot(h) is None, f"snapshot {h} reported removed but still readable"

    # Every protected snapshot is still readable, and every blob referenced by
    # its model manifest is still present.
    for h in [*hashes[:3], hashes[-1]]:
        row = store.get_snapshot(h)
        assert row is not None, f"protected snapshot {h} was gc'd"
        mf = json.loads(row["manifest_json"])
        for bh in mf["model"].values():
            assert store.has_blob(bh), f"blob {bh} gc'd but snapshot {h} still references it"
def test_gc_dry_run_does_not_delete(store):
    """``gc(..., dry_run=True)`` must leave every snapshot and its blobs in place.

    The test uses ``keep_last=0`` (no recency window). Whatever a real run would
    remove, a dry run must delete nothing.
    """
    model, opt = _make_model_and_opt(seed=8)
    parent = None
    hashes = []
    for step in range(3):
        with torch.no_grad():
            model.fc1.weight.add_(0.2)
        parent = snapshot(model, opt, step=step, metrics={"loss": 1.0 * step},
                          parent_hash=parent, store=store)
        hashes.append(parent)

    # The returned report is not inspected; only the absence of side effects matters.
    gc(keep_last=0, dry_run=True, store=store)

    for h in hashes:
        row = store.get_snapshot(h)
        assert row is not None, f"dry-run gc deleted snapshot {h}"
        for bh in json.loads(row["manifest_json"])["model"].values():
            assert store.has_blob(bh), f"dry-run gc deleted blob {bh} of snapshot {h}"
# ---------------------------------------------------------------------------
# Hash utility sanity
# ---------------------------------------------------------------------------
def test_hash_bytes_deterministic():
    """``hash_bytes`` gives equal digests for equal input and differs on a one-byte change."""
    digest = hash_bytes(b"hello world")
    again = hash_bytes(b"hello world")
    tweaked = hash_bytes(b"hello worlD")  # only the final byte differs (d -> D)
    assert digest == again
    assert digest != tweaked