# Spike 008 — Streaming DiLoCo outer-loop smoke
**Closes**: V2 (DiLoCo "deferred to v0.2") in `docs/VISION_VALIDATION.md`.
## Goal
Bolt the DiLoCo outer-loop pseudo-gradient sync onto the framework using
`torchft.local_sgd.DiLoCo` (see `docs/adrs/ADR-003-diloco-impl.md`).
Verify:
1. Two in-process replicas converge to identical parameters after outer sync.
2. Outer Nesterov momentum is actually populated (i.e. the outer optimizer
ran).
3. The pseudo-gradient sign convention is what we expect (sign flip detected
by an explicit unit test).
4. Importing torchft does not regress Spike 005's existing 38 tests.
The smoke runs in a single process without NCCL. A mock `Manager.allreduce`
performs real cross-replica averaging through a shared buffer.
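As a rough illustration of the shared-buffer idea, the sketch below averages values deposited by in-process replicas. It is not the spike's mock: the class name `SharedBufferAllreduce` and its methods are hypothetical, and plain Python lists stand in for tensors so the example has no dependencies.

```python
from typing import List


class SharedBufferAllreduce:
    """Hypothetical stand-in for a mocked allreduce across in-process replicas."""

    def __init__(self, world_size: int) -> None:
        self.world_size = world_size
        self._buffer: List[List[float]] = []

    def contribute(self, values: List[float]) -> None:
        # Each replica deposits a copy of its values into the shared buffer.
        self._buffer.append(list(values))

    def average(self) -> List[float]:
        # The average is only meaningful once every replica has contributed.
        if len(self._buffer) != self.world_size:
            raise RuntimeError("not all replicas have contributed")
        width = len(self._buffer[0])
        return [
            sum(rep[i] for rep in self._buffer) / self.world_size
            for i in range(width)
        ]

    def reset(self) -> None:
        # Clear the buffer before the next outer round.
        self._buffer.clear()


allreduce = SharedBufferAllreduce(world_size=2)
allreduce.contribute([1.0, 4.0])  # replica 0
allreduce.contribute([3.0, 0.0])  # replica 1
averaged = allreduce.average()
print(averaged)  # [2.0, 2.0]
```

Because every replica reads the same averaged values, any deterministic outer update applied to them leaves the replicas with identical parameters, which is the property the first check relies on.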
## Files
- `composer_diloco.py`: `make_diloco_outer_loop(...)` wrapper around
`torchft.local_sgd.DiLoCo`. Documents the sign convention.
- `tests/test_diloco_smoke.py`: 3 acceptance tests.
## Acceptance
| Criterion | Status |
|---|---|
| 2 replicas converge after 2 outer rounds | ✓ test 1 |
| Nesterov momentum state populated | ✓ test 1 |
| Sync fires once per outer round per replica | ✓ test 1 |
| Pseudo-gradient sign convention verified | ✓ test 2 |
| No regression in Spike 005 imports | ✓ test 3 |
| Spike 005's 38 tests still pass after this wave | (verified separately) |
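The sign-convention and momentum criteria can be illustrated with a small dependency-free sketch. It assumes the convention from the DiLoCo paper (pseudo-gradient = previous global parameters minus local parameters); the convention the spike actually verifies is the one documented in `composer_diloco.py`. `OuterNesterovSGD` and `outer_round` are hypothetical names, and the update rule mirrors `torch.optim.SGD` with `nesterov=True` and zero dampening.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class OuterNesterovSGD:
    """Nesterov-momentum SGD over a flat list of floats (illustration only)."""
    lr: float
    momentum: float
    momentum_buffer: List[float] = field(default_factory=list)

    def step(self, params: List[float], grads: List[float]) -> List[float]:
        if not self.momentum_buffer:
            # First step: the buffer is initialised to the gradient.
            self.momentum_buffer = list(grads)
        else:
            self.momentum_buffer = [
                self.momentum * b + g
                for b, g in zip(self.momentum_buffer, grads)
            ]
        # Nesterov look-ahead: step along grad + momentum * buffer.
        return [
            p - self.lr * (g + self.momentum * b)
            for p, g, b in zip(params, grads, self.momentum_buffer)
        ]


def outer_round(global_params, replica_params, opt, sign=1.0):
    """Average the per-replica pseudo-gradients and apply one outer step."""
    world = len(replica_params)
    avg = [
        sign * sum(g - local[i] for local in replica_params) / world
        for i, g in enumerate(global_params)
    ]
    return opt.step(global_params, avg)


global_params = [1.0]
replica_params = [[0.5], [0.75]]  # local parameters after the inner steps

# With lr=1 and no momentum, the assumed sign lands on the replica mean.
expected = outer_round(global_params, replica_params,
                       OuterNesterovSGD(lr=1.0, momentum=0.0))
# A flipped sign moves the global parameters away from the replicas.
flipped = outer_round(global_params, replica_params,
                      OuterNesterovSGD(lr=1.0, momentum=0.0), sign=-1.0)
print(expected, flipped)  # [0.625] [1.375]

# After one step the momentum buffer is populated.
nesterov = OuterNesterovSGD(lr=0.5, momentum=0.5)
outer_round(global_params, replica_params, nesterov)
print(nesterov.momentum_buffer)  # [0.375]
```

A sign-flip test in this style only needs to check which side of the old global value the update lands on, so it stays cheap and independent of the optimizer's hyperparameters.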
## Future work (v0.2 Streaming DiLoCo)
- `fragment_sync_delay > 0` requires CUDA streams. Spike 008 uses
`fragment_sync_delay=0` (vanilla DiLoCo) for the smoke.
- Multiple fragments via `model_fragments=[frag_0, frag_1, ...]` can be
configured by `make_diloco_outer_loop()` but are not exercised in the smoke.
- A real `torch.distributed` backend (NCCL) for multi-node training is intended
to be one config switch away: replace the mock `Manager` with a real
`torchft.Manager`.
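To make the multi-fragment item more concrete, here is a minimal sketch of how fragment syncs could be staggered across inner steps. It assumes evenly spaced offsets and is not taken from torchft; `staggered_sync_steps` and its parameters are hypothetical names used only for illustration.

```python
from typing import Dict, List


def staggered_sync_steps(
    num_fragments: int, sync_every: int, total_steps: int
) -> Dict[int, List[int]]:
    """Map each fragment index to the inner steps at which it would sync."""
    if sync_every % num_fragments != 0:
        raise ValueError("sync_every must be divisible by num_fragments")
    stride = sync_every // num_fragments
    schedule: Dict[int, List[int]] = {}
    for frag in range(num_fragments):
        # Fragment `frag` first syncs at its offset, then every sync_every steps.
        offset = (frag + 1) * stride
        schedule[frag] = [
            step
            for step in range(offset, total_steps + 1)
            if (step - offset) % sync_every == 0
        ]
    return schedule


schedule = staggered_sync_steps(num_fragments=2, sync_every=4, total_steps=8)
print(schedule)  # {0: [2, 6], 1: [4, 8]}
```

Under this assumption each fragment still syncs once per `sync_every` inner steps, but communication is spread out so that only one fragment is in flight at a time.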
## Cost / time
- Pure CPU, single process, no GPU.
- Tests run in <2 seconds total.
## Dependencies added
- `torchft-nightly` (BSD-3, Meta-maintained, `pip install torchft-nightly`)