# DiLoCo Reference Implementation Reconnaissance
**Date:** 2026-05-25
**Purpose:** Pick ONE PyTorch reference implementation of (Streaming) DiLoCo to bolt onto
the composer-replication-framework outer-loop optimizer. Feeds ADR-003.
**Bias:** simple + working > fancy + theoretically-better. Library > research codebase.
---
## TL;DR — Recommendation
**Use `meta-pytorch/torchft`'s `torchft.local_sgd.DiLoCo` context manager.**
It is a maintained library (not a research codebase), BSD-3 licensed, and supports both
vanilla DiLoCo and Streaming DiLoCo through one class. Critically, it is unit-testable
in a single process by passing a mocked `Manager` (`create_autospec(Manager)`) whose `allreduce` returns a `_DummyWork`.
Their own `torchft/local_sgd_test.py` already demonstrates the exact pattern Spike 008 needs.
The Streaming DiLoCo paper (Liu et al. 2025, arXiv:2501.18512) has no separate community
implementation — torchft *is* the reference implementation as of mid-2026. PrimeIntellect's
two repos are either too minimal (`diloco_simple`, no LICENSE, NCCL-locked, no Streaming) or
deprecated (`OpenDiLoCo`, hivemind-based, "no longer maintained" per its own README).
---
## Candidates Audited (primary sources only)
### A1. PrimeIntellect-ai/diloco_simple
- URL: https://github.com/PrimeIntellect-ai/diloco_simple
- License: **NONE** (no LICENSE file in repo — confirmed via `git clone` + `ls`).
All-rights-reserved by default under copyright law. **Cannot legally vendor or fork.**
- Last commit: **2024-05-31** (`be38ec4 add weight decay`).
- Activity: 8 commits total, ever. Two main authors. Effectively abandoned.
- Shape: single 180-LOC research script (`pure_torch_diloco.py`), pedagogical demo.
- Streaming DiLoCo? **No.** Vanilla DiLoCo only.
- Distributed: **Hard-coded NCCL via `torchrun`** + `init_process_group(backend="nccl")`.
Pulls in `wandb`, `transformers`, HuggingFace `datasets`, `cyclopts`, and trains a
full LlamaForCausalLM on C4. Not a library — a benchmark script.
- Verdict: **REJECT.** No license, no Streaming, no library API, NCCL-only, deps on
HF + wandb just to run. Useful as an *algorithm reference*, not as code to depend on.
### A2. PrimeIntellect-ai/OpenDiLoCo
- URL: https://github.com/PrimeIntellect-ai/OpenDiloco
- License: present (believed to be Apache-2.0; not re-verified, and moot given the deprecation below).
- Status: **Officially deprecated.** README first paragraph:
> "**Important Notice**: OpenDiLoCo is no longer maintained. For our production-ready
> distributed training solution, please check out `prime`."
- Built on: `hivemind` (DHT-based decentralized training). Multi-machine only.
- Streaming DiLoCo? No.
- Verdict: **REJECT.** Deprecated by its authors. Hivemind dependency would force us to
set up DHT initial peers just to run a unit test.
### A3. PrimeIntellect-ai/prime (a.k.a. INTELLECT-1 framework)
- URL: https://github.com/PrimeIntellect-ai/prime — note: the GitHub org now uses this
repo for their CLI/SDK; the original training framework was rebranded.
- The actual INTELLECT-1 training code uses an `ElasticDeviceMesh` abstraction and is
a full distributed training stack, not an algorithm library.
- Verdict: **REJECT.** Production framework, not a drop-in library. Coupling a 1.5k-LOC
fault-tolerant elastic mesh into our test framework is the opposite of "simple + working".
### A4. DeepMind reference implementation (Douillard et al., arXiv:2311.08105)
- **No public reference implementation exists.** The DiLoCo paper is algorithm-only.
Confirmed: paper has no associated GitHub link in arXiv abstract or PDF; HuggingFace
papers page links no code. DeepMind has not open-sourced their internal trainer.
- Verdict: **N/A — does not exist.**
### A5. meta-pytorch/torchft ← **CHOSEN**
- URL: https://github.com/meta-pytorch/torchft
- License: **BSD 3-Clause** (verified: `head -5 LICENSE` → "BSD 3-Clause License").
- Last commit on main: **2026-04-03** (HEAD `7eb7087 Add torchcomms ProcessGroup shim
for fault-tolerant reconfiguration`).
- Activity: 312 commits, multiple Meta contributors, recent commits across 2025 and 2026,
active CI, nightly PyPI builds at https://pypi.org/project/torchft-nightly/.
- Shape: **library**, not a research codebase. `torchft/` is a proper Python package with
`local_sgd.py`, `manager.py`, `process_group.py`, `local_sgd_test.py` (real pytest unit
tests), pyproject.toml, BSD-3.
- Streaming DiLoCo? **Yes** — the `DiLoCo` class is itself a Streaming DiLoCo
generalization (`fragment_sync_delay`, `fragment_update_alpha`); pass a single-element
`model_fragments=[model]` for vanilla DiLoCo.
- Source comment confirms: `"""... DiLoCo paper: https://arxiv.org/pdf/2311.08105 /
Streaming DiLoCo paper: https://arxiv.org/pdf/2501.18512 """`
---
## Deep Dive: torchft (the chosen one)
### (1) Repo metadata
| Field | Value |
|---|---|
| URL | https://github.com/meta-pytorch/torchft |
| License | BSD 3-Clause |
| HEAD commit | `7eb7087` (2026-04-03) |
| Total commits on main | 312 |
| Activity level | **Active** — commits in 2025 + 2026, Meta-maintained, PyPI nightly builds |
| Distribution | `pip install torchft-nightly` (prebuilt wheels) **OR** install from source (requires Rust + protobuf-compiler + maturin — only because of the Lighthouse/process-group Rust ext, not the algorithm code) |
| Python | `requires-python = ">=3.8"`; `torch>=2.7` per `pyproject.toml` |
### (2) Exact API / extension point
The integration target is `torchft/local_sgd.py`. Two relevant classes:
```python
# Public class: drop-in context manager
class DiLoCo:
    def __init__(
        self,
        manager: Manager,                    # we mock this
        model_fragments: List[nn.Module],    # [model] for vanilla DiLoCo
        inner_optimizer: optim.Optimizer,
        outer_optimizer: optim.Optimizer | list[optim.Optimizer],
        sync_every: int,                     # N inner steps
        backup_device: Optional[torch.device] = None,
        pin_memory: bool = True,
        use_bucketization: bool = False,
        bucket_cap_mb: Optional[int] = None,
        should_quantize: bool = False,
        fragment_sync_delay: int = 0,        # τ in Streaming DiLoCo paper
        fragment_update_alpha: float = 0.0,
    ) -> None: ...
```
The **pseudo-gradient** is computed in `_StreamingDiLoCoFragment._save_grads()`
(`torchft/local_sgd.py` line 324):
```python
def _save_grads(self) -> None:
    """Saves pseudo-gradients of the parameters"""
    with torch.no_grad():
        for name, p in self._model_fragment.named_parameters():
            local_param = p.to_local() if isinstance(p, DTensor) else p
            pseudogradient = self.original_parameters[name].to(p.device) - local_param
            self._grads[name] = pseudogradient
```
Note the **sign**: `original − local`, i.e. `θ_initial − θ_local`. `_set_grads` later
copies this into `p.grad`; provided `p` has been reset to `θ_initial` when the outer step
runs, a plain SGD step `p ← p − lr · grad` becomes
`p ← θ_initial − lr · (θ_initial − θ_local) = θ_initial + lr · (θ_local − θ_initial)`,
a step *toward* `θ_local`. Our spec defines δ = θ_local − θ_initial; torchft uses the
negation. Either convention works as long as the outer update direction is consistent:
torchft pairs a positive `outer_lr` (e.g. 0.7) with SGD, which subtracts the gradient, so
the math nets out. **Be careful when unit-testing the sign in Spike 008.**
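A minimal pure-Python check of the sign argument, using plain SGD without momentum so that only the sign matters. `outer_sgd_step` is a hypothetical helper for illustration, not torchft code; parameters are flat lists of floats.

```python
# Hypothetical helper (not torchft code): plain SGD, no momentum.
def outer_sgd_step(theta_initial, grad, lr):
    """p <- p - lr * grad, starting from p = θ_initial."""
    return [p - lr * g for p, g in zip(theta_initial, grad)]


theta_initial = [1.0, -2.0]
theta_local = [0.6, -1.0]  # weights after the inner steps

# torchft convention: pseudogradient = θ_initial − θ_local, fed to SGD as-is
torchft_grad = [a - b for a, b in zip(theta_initial, theta_local)]
new_theta = outer_sgd_step(theta_initial, torchft_grad, lr=0.7)

# Our spec's convention: δ = θ_local − θ_initial, negated before the SGD step
delta = [b - a for a, b in zip(theta_initial, theta_local)]
new_theta_from_delta = outer_sgd_step(theta_initial, [-d for d in delta], lr=0.7)

# Both land 70% of the way from θ_initial toward θ_local.
```

With momentum the step length changes, but the direction argument is the same.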
The **outer Nesterov step** is in `_StreamingDiLoCoFragment.perform_sync()` (line 423):
```python
if should_commit:
    self._set_grads()              # write pseudogradient into p.grad
    self._outer_optimizer.step()   # Nesterov SGD step (user-provided)
    self.save_parameters()
    self._merge_parameters()
self._outer_optimizer.zero_grad()
```
The Nesterov-ness lives in the user-provided outer optimizer, e.g.:
```python
outer_optimizer = torch.optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)
```
This matches the DiLoCo paper's outer optimizer (Douillard §3 specifies Nesterov momentum for the outer step).
The cross-replica all-reduce happens in `_average_grads()` (called from `prepare_sync`)
via `self._manager.allreduce(...)` — which is the seam we mock for single-process tests.
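To make one outer round concrete, here is a minimal pure-Python sketch: average the torchft-convention pseudo-gradients across replicas (the job the all-reduce does), then apply SGD with Nesterov momentum following the update rule in PyTorch's `torch.optim.SGD` documentation (dampening 0, no weight decay). The helpers `average_pseudograds` and `nesterov_sgd_step` are hypothetical, not torchft API, and parameters are flat lists of floats.

```python
# Hypothetical helpers (not torchft API): one vanilla-DiLoCo outer round.

def average_pseudograds(theta_initial, theta_locals):
    """Mean over replicas of (θ_initial − θ_local): what the all-reduce yields."""
    n = len(theta_locals)
    return [
        sum(t0 - local[j] for local in theta_locals) / n
        for j, t0 in enumerate(theta_initial)
    ]


def nesterov_sgd_step(params, grads, momentum_buf=None, lr=0.7, momentum=0.9):
    """In-place SGD(momentum, nesterov=True) step, as documented for torch.optim.SGD
    with dampening=0 and weight_decay=0. Returns the updated momentum buffer."""
    if momentum_buf is None:
        momentum_buf = list(grads)  # first step: b_1 = g_1
    else:
        momentum_buf = [momentum * b + g for b, g in zip(momentum_buf, grads)]
    for j, (g, b) in enumerate(zip(grads, momentum_buf)):
        params[j] -= lr * (g + momentum * b)  # Nesterov look-ahead: g + μ·b
    return momentum_buf


# Two replicas that started from the same θ_initial and drifted apart.
theta_initial = [1.0, -2.0]
replica_locals = [[0.6, -1.0], [0.8, -1.4]]

params = list(theta_initial)  # the outer step starts from θ_initial
delta = average_pseudograds(theta_initial, replica_locals)
mom_buf = nesterov_sgd_step(params, delta)
```

A second round would pass `mom_buf` back in. Because every replica applies this step to the same averaged pseudo-gradient from the same `θ_initial`, the replicas stay identical after each outer round, which is what assertion 1 of the Spike 008 test checks.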
### (3) torch.distributed dependency for testing?
**No, not for unit tests.** The `Manager` is mockable. From `torchft/local_sgd_test.py`:
```python
from unittest.mock import create_autospec, MagicMock

import torch

from torchft.manager import Manager
from torchft.work import _DummyWork


def create_manager() -> MagicMock:
    manager = create_autospec(Manager)
    manager.errored.return_value = None

    def mock_allreduce(tensor: torch.Tensor, should_quantize: bool = False):
        return _DummyWork(tensor)  # returns the same tensor unchanged

    manager.allreduce.side_effect = mock_allreduce
    return manager
```
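This bypasses NCCL/Gloo entirely. `_DummyWork` just wraps the tensor and returns it as
the "all-reduced" result, so a single-process test with `world_size=1` works directly.
A 2-replica test runs two `DiLoCo` instances with two model copies in the same process
and a `mock_allreduce` that *averages* the two tensors manually before returning.
(Their `test_bucketization_correctness` does exactly this.) One caveat: each replica's
sync, including the `wait()` on its all-reduce, runs to completion inside that replica's
own `inner_optimizer.step()`, so the two replicas must reach their `allreduce` calls
concurrently (e.g. one thread per replica) for the mock to see both tensors at once.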
For real distributed runs torchft uses Gloo or NCCL via `torchft.process_group`
(reconfigurable PGs that wrap `torch.distributed`). We do not need this for Spike 008.
### (4) Library, research codebase, or paper-companion?
**Library.** Strong evidence:
- Proper Python package layout (`torchft/__init__.py`, modules per concern).
- Real unit tests (`*_test.py` per module) — not "run this script" demos.
- BSD-3-Clause LICENSE (vs. diloco_simple having none, signaling "personal demo").
- Nightly PyPI distribution (`torchft-nightly`) with prebuilt wheels.
- Documentation site at https://pytorch.org/torchft.
- `meta-pytorch` org: maintained by Meta, alongside `torchtitan`.
- README explicitly: *"torchft is designed to provide the primitives required to
implement fault tolerance in any application/train script"* — i.e. a building block.
Only friction: installing **from source** needs Rust (pyo3 + maturin) and
protobuf-compiler. This is for the Rust Lighthouse/process-group extension which we
**do not need** for Spike 008's mock-based tests. Two clean options:
- (a) `pip install torchft-nightly` — uses prebuilt wheel, no Rust toolchain needed.
- (b) Vendor `torchft/local_sgd.py` + the few helpers (`work.py::_DummyWork`,
type stubs for `Manager`) into our repo under BSD-3 attribution. ~700 LOC total.
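Keeping both options open could be as simple as an import shim that prefers the installed wheel and falls back to a vendored copy. A minimal sketch, assuming the vendored module would live at `framework/diloco/_vendored/local_sgd.py` (the path this document proposes for the vendored copy; it does not exist yet); `import_first` is a hypothetical helper of ours.

```python
import importlib


def import_first(module_names, attr):
    """Return `attr` from the first module in `module_names` that imports cleanly.

    Hypothetical helper: lets the framework prefer the installed torchft wheel
    and fall back to a vendored copy without touching call sites.
    """
    failures = []
    for name in module_names:
        try:
            module = importlib.import_module(name)
        except ImportError as exc:  # also covers ModuleNotFoundError
            failures.append(f"{name}: {exc}")
            continue
        return getattr(module, attr)
    raise ImportError(
        f"no module in {module_names} provides {attr!r}: " + "; ".join(failures)
    )


# Intended use (the vendored path is a proposal, not an existing module):
# DiLoCo = import_first(
#     ["torchft.local_sgd", "framework.diloco._vendored.local_sgd"], "DiLoCo"
# )
```

Because call sites only ever see `DiLoCo`, switching from option (a) to option (b) would not touch the trainer code.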
### (5) Minimum viable test pattern for Spike 008
Goal: **2 replicas × 4 inner steps × 2 outer rounds on a tiny model**, single-process, no NCCL.
```python
# spikes/008-diloco-outer-loop/tests/test_diloco_two_replicas.py
"""
Spike 008: prove the DiLoCo outer-loop math is correct under our framework.
Runs entirely in a single process (one thread per replica), no torch.distributed required.
"""
import copy
import threading
from unittest.mock import create_autospec, MagicMock

import torch
import torch.nn as nn
import torch.optim as optim

from torchft.local_sgd import DiLoCo
from torchft.manager import Manager
from torchft.work import _DummyWork

SYNC_EVERY = 4  # 4 inner steps per outer round
OUTER_ROUNDS = 2


class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

    def forward(self, x):
        return self.net(x)


class _PairBuffer:
    """Test helper standing in for a 2-rank process group.

    Each replica calls allreduce() in the same (named_parameters) order, so the
    k-th call from rank 0 pairs with the k-th call from rank 1. A replica's whole
    sync (allreduce + wait + outer step) finishes inside its own inner.step(), so
    the replicas must run in separate threads for the pairing to work.
    """

    def __init__(self):
        self.barrier = threading.Barrier(2, timeout=30)  # fail instead of hanging
        self.slots = [None, None]


def _make_avg_manager(buf: _PairBuffer, rank: int) -> MagicMock:
    """Manager whose allreduce averages tensors across replicas via shared buffer."""
    mgr = create_autospec(Manager)
    mgr._use_async_quorum = False  # DiLoCo requires a synchronous quorum
    mgr.errored.return_value = None
    mgr.should_commit.return_value = True
    mgr.current_step.return_value = 0

    def avg_allreduce(tensor, *args, **kwargs):  # ignores should_quantize etc.
        buf.slots[rank] = tensor.clone()
        buf.barrier.wait()  # both replicas have deposited their tensor
        mean = (buf.slots[0] + buf.slots[1]) / 2.0
        buf.barrier.wait()  # both have read the slots before the next call reuses them
        tensor.copy_(mean)
        return _DummyWork(tensor)

    mgr.allreduce.side_effect = avg_allreduce
    return mgr


def _make_batches(shift):
    # Replicas see DIFFERENT data: that is the whole point of DiLoCo
    return [
        (torch.randn(8, 4) + shift * outer_round, torch.randn(8, 2))
        for outer_round in range(OUTER_ROUNDS)
        for _ in range(SYNC_EVERY)
    ]


def _run_replica(model, inner, batches, buf, errors):
    try:
        for x, y in batches:
            inner.zero_grad()
            ((model(x) - y) ** 2).mean().backward()
            inner.step()  # DiLoCo's post-step hook syncs every SYNC_EVERY steps
    except BaseException as exc:  # surface failures to the main thread
        errors.append(exc)
        buf.barrier.abort()  # unblock the other replica


def test_diloco_two_replicas_four_inner_two_outer():
    torch.manual_seed(0)
    model_a = TinyMLP()
    model_b = copy.deepcopy(model_a)  # identical init = same θ_initial

    # Inner optimizers (one per replica)
    inner_a = optim.AdamW(model_a.parameters(), lr=1e-3)
    inner_b = optim.AdamW(model_b.parameters(), lr=1e-3)
    # Outer Nesterov (one per replica, same hyperparams)
    outer_a = optim.SGD(model_a.parameters(), lr=0.7, momentum=0.9, nesterov=True)
    outer_b = optim.SGD(model_b.parameters(), lr=0.7, momentum=0.9, nesterov=True)

    # Shared buffer: both DiLoCo wrappers funnel through one "process group" of size 2
    buf = _PairBuffer()
    mgr_a = _make_avg_manager(buf, rank=0)
    mgr_b = _make_avg_manager(buf, rank=1)

    # Build batches on the main thread so the RNG sequence stays deterministic
    batches_a, batches_b = _make_batches(0.1), _make_batches(-0.1)

    with DiLoCo(mgr_a, [model_a], inner_a, outer_a, sync_every=SYNC_EVERY), \
         DiLoCo(mgr_b, [model_b], inner_b, outer_b, sync_every=SYNC_EVERY):
        # Snapshot θ_initial
        theta_initial_a = {n: p.detach().clone() for n, p in model_a.named_parameters()}

        errors = []
        threads = [
            threading.Thread(target=_run_replica, args=(model_a, inner_a, batches_a, buf, errors)),
            threading.Thread(target=_run_replica, args=(model_b, inner_b, batches_b, buf, errors)),
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        assert not errors, errors

        # Assertions:
        # 1. Both replicas now hold IDENTICAL parameters (averaged via mock allreduce).
        for (na, pa), (nb, pb) in zip(model_a.named_parameters(), model_b.named_parameters()):
            torch.testing.assert_close(pa, pb, msg=f"Replicas diverged at {na}")
        # 2. Parameters changed from θ_initial (outer optimizer actually stepped).
        any_change = any(
            not torch.equal(p, theta_initial_a[n]) for n, p in model_a.named_parameters()
        )
        assert any_change, "outer optimizer did not move the parameters"
        # 3. The outer optimizer holds Nesterov momentum state for every parameter
        #    (proves the SGD(nesterov=True) actually ran).
        n_params = len(list(model_a.parameters()))
        assert len(outer_a.state_dict()["state"]) == n_params
        # 4. Sync fired once per outer round per replica
        #    (assumes DiLoCo calls start_quorum() once per sync).
        assert mgr_a.start_quorum.call_count == OUTER_ROUNDS
        assert mgr_b.start_quorum.call_count == OUTER_ROUNDS
```
**Why this works:**
- `DiLoCo` registers a post-step hook on `inner_optimizer` (see `__enter__`). The
hook increments `_local_step` and triggers `prepare_sync` / `perform_sync` on every
`sync_every` boundary — fully automatic, our test only calls `inner.step()`.
- `_DummyWork.wait()` is a no-op. `_average_grads` calls `manager.allreduce(...)`
which our `avg_allreduce` mocks to do real cross-replica averaging through `buf`.
- `manager.should_commit.return_value = True` lets the outer optimizer fire on each
outer round; setting it to `False` lets us also test rollback semantics.
- Everything runs in one process, so pytest needs no launcher. Put the file in a new
  `spikes/008-diloco-outer-loop/tests/`, following the layout of `spikes/005-integrated-trainer-skeleton/tests/`.
**Install for this spike:** `pip install torchft-nightly` in the eidolon venv. If the
nightly wheel proves brittle, fallback: vendor `local_sgd.py` + `work.py` + a
minimal `manager.py` stub (≈800 LOC) into `framework/diloco/_vendored/` with BSD-3
attribution.
---
## Risks & Mitigations
| Risk | Likelihood | Mitigation |
|---|---|---|
| `torchft-nightly` wheel breaks against torch 2.x | Med | Pin to a specific nightly hash; or vendor `local_sgd.py` directly under BSD-3. |
| `torchft.manager.Manager` import pulls in Rust ext at import time | Low | `create_autospec(Manager)` introspects the real class, so `from torchft.manager import Manager` must succeed even though no `Manager` method actually runs. If that module imports the compiled extension, the prebuilt `torchft-nightly` wheel already ships it; only a failed source build bites, and then we vendor `local_sgd.py` and stub `Manager`. |
| Sign convention of pseudogradient causes our outer optimizer to move the wrong way | Med | Assertion 2 in the test pattern above only checks that params moved from θ_initial, not in which direction. A second test should compare the update against a hand-computed expected value. |
| `fragment_sync_delay > 0` (true Streaming) requires CUDA streams | Med | Spike 008 starts with `fragment_sync_delay=0` (= vanilla DiLoCo). Streaming variant deferred to Spike 009 once basic loop works. |
| Requires `torch>=2.7` per pyproject | Low | Framework already on torch 2.x; check exact pin. If <2.7, we vendor. |
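For the direction test suggested in the sign-convention row, the expected parameters after the first outer round have a closed form. This assumes PyTorch's documented `torch.optim.SGD` update with momentum μ, `nesterov=True`, zero dampening and no weight decay (the momentum buffer starts as the first gradient), R replicas, outer learning rate η, and torchft's sign convention; Δ and θ̄_local are notation introduced here.

$$
\begin{aligned}
\bar\theta_{\text{local}} &= \frac{1}{R}\sum_{i=1}^{R} \theta_{\text{local}}^{(i)}, \qquad
\Delta = \frac{1}{R}\sum_{i=1}^{R}\left(\theta_{\text{initial}} - \theta_{\text{local}}^{(i)}\right) = \theta_{\text{initial}} - \bar\theta_{\text{local}}, \\
b_1 &= \Delta, \\
\theta_1 &= \theta_{\text{initial}} - \eta\,(\Delta + \mu\, b_1)
          = \theta_{\text{initial}} - \eta\,(1+\mu)\,\Delta
          = \theta_{\text{initial}} + \eta\,(1+\mu)\left(\bar\theta_{\text{local}} - \theta_{\text{initial}}\right).
\end{aligned}
$$

With η = 0.7 and μ = 0.9, η(1+μ) = 1.33: after round 1 every replica should sit at θ_initial + 1.33·(θ̄_local − θ_initial), slightly past the replica mean, whereas a flipped sign would move it away from the mean. The mocked `allreduce` already produces the averaged pseudo-gradient (that is Δ), so one option is for the test to record it there and compare against this formula. From round 2 on, the momentum buffer carries history and this closed form no longer applies.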
---
## Decision (for ADR-003)
Adopt **`torchft.local_sgd.DiLoCo`** as the reference DiLoCo / Streaming DiLoCo
implementation. Integrate via `pip install torchft-nightly` for Spike 008. If
brittleness emerges, vendor `local_sgd.py` (BSD-3) into `framework/diloco/_vendored/`.
For the framework's outer-loop optimizer abstraction (the actual ADR-003 question):
mirror torchft's `DiLoCo(manager, [model_fragments], inner_opt, outer_opt, sync_every)`
constructor shape so that swapping our wrapper for the upstream class is a one-line
change. Compute pseudogradient as `θ_local − θ_initial` (our convention) and negate
when handing to the outer optimizer, OR follow torchft's `θ_initial − θ_local`
convention end-to-end. **Pick one and document it loudly.**