
DiLoCo Reference Implementation Reconnaissance

Date: 2026-05-25
Purpose: Pick ONE PyTorch reference implementation of (Streaming) DiLoCo to bolt onto the composer-replication-framework outer-loop optimizer. Feeds ADR-003.

Bias: simple + working > fancy + theoretically-better. Library > research codebase.


TL;DR — Recommendation

Use meta-pytorch/torchft's torchft.local_sgd.DiLoCo context manager.

It is a maintained library (not a research codebase), BSD-3 licensed, and supports both vanilla DiLoCo and Streaming DiLoCo through one class. Critically, it is unit-testable in a single process by passing a mocked Manager (create_autospec(Manager)) whose allreduce returns a _DummyWork. Their own torchft/local_sgd_test.py already demonstrates the mocking pattern Spike 008 needs.

We found no separate community implementation of the Streaming DiLoCo paper (Liu et al. 2025, arXiv:2501.18512); torchft is the de facto reference implementation as of mid-2026. PrimeIntellect's two DiLoCo repos are either too minimal (diloco_simple: no LICENSE, NCCL-locked, no Streaming) or deprecated (OpenDiLoCo: hivemind-based, "no longer maintained" per its own README).


Candidates Audited (primary sources only)

A1. PrimeIntellect-ai/diloco_simple

  • URL: https://github.com/PrimeIntellect-ai/diloco_simple
  • License: NONE (no LICENSE file in the repo, confirmed via git clone + ls). Without a license the code is all-rights-reserved by default under copyright law, so we cannot legally vendor or redistribute it.
  • Last commit: 2024-05-31 (be38ec4 add weight decay).
  • Activity: 8 commits total, ever. Two main authors. Effectively abandoned.
  • Shape: single 180-LOC research script (pure_torch_diloco.py), pedagogical demo.
  • Streaming DiLoCo? No. Vanilla DiLoCo only.
  • Distributed: Hard-coded NCCL via torchrun + init_process_group(backend="nccl"). Pulls in wandb, transformers, HuggingFace datasets, cyclopts, and trains a full LlamaForCausalLM on C4. Not a library — a benchmark script.
  • Verdict: REJECT. No license, no Streaming, no library API, NCCL-only, deps on HF + wandb just to run. Useful as an algorithm reference, not as code to depend on.

A2. PrimeIntellect-ai/OpenDiLoCo

  • URL: https://github.com/PrimeIntellect-ai/OpenDiloco
  • License: present (assumed Apache-2.0; not re-verified, and moot given the deprecation below).
  • Status: Officially deprecated. README first paragraph:

    "Important Notice: OpenDiLoCo is no longer maintained. For our production-ready distributed training solution, please check out prime."

  • Built on: hivemind (DHT-based decentralized training). Multi-machine only.
  • Streaming DiLoCo? No.
  • Verdict: REJECT. Deprecated by its authors. Hivemind dependency would force us to set up DHT initial peers just to run a unit test.

A3. PrimeIntellect-ai/prime (a.k.a. INTELLECT-1 framework)

  • URL: https://github.com/PrimeIntellect-ai/prime — note: the GitHub org now uses this repo for their CLI/SDK; the original training framework was rebranded.
  • The actual INTELLECT-1 training code uses an ElasticDeviceMesh abstraction and is a full distributed training stack, not an algorithm library.
  • Verdict: REJECT. Production framework, not a drop-in library. Coupling a 1.5k-LOC fault-tolerant elastic mesh into our test framework is the opposite of "simple + working".

A4. DeepMind reference implementation (Douillard et al., arXiv:2311.08105)

  • No public reference implementation exists. The DiLoCo paper is algorithm-only. Confirmed: paper has no associated GitHub link in arXiv abstract or PDF; HuggingFace papers page links no code. DeepMind has not open-sourced their internal trainer.
  • Verdict: N/A — does not exist.

A5. meta-pytorch/torchft ← CHOSEN

  • URL: https://github.com/meta-pytorch/torchft
  • License: BSD 3-Clause (verified: head -5 LICENSE → "BSD 3-Clause License").
  • Last commit on main: 2026-04-03 (HEAD 7eb7087 Add torchcomms ProcessGroup shim for fault-tolerant reconfiguration).
  • Activity: 312 commits, multiple Meta contributors, recent commits across 2025 and 2026, active CI, nightly PyPI builds at https://pypi.org/project/torchft-nightly/.
  • Shape: library, not a research codebase. torchft/ is a proper Python package with local_sgd.py, manager.py, process_group.py, local_sgd_test.py (real pytest unit tests), pyproject.toml, BSD-3.
  • Streaming DiLoCo? Yes — the DiLoCo class is itself a Streaming DiLoCo generalization (fragment_sync_delay, fragment_update_alpha); pass a single-element model_fragments=[model] for vanilla DiLoCo.
  • Source comment confirms: """... DiLoCo paper: https://arxiv.org/pdf/2311.08105 / Streaming DiLoCo paper: https://arxiv.org/pdf/2501.18512 """

Deep Dive: torchft (the chosen one)

(1) Repo metadata

| Field | Value |
|---|---|
| URL | https://github.com/meta-pytorch/torchft |
| License | BSD 3-Clause |
| HEAD commit | 7eb7087 (2026-04-03) |
| Total commits on main | 312 |
| Activity level | Active: commits in 2025 + 2026, Meta-maintained, PyPI nightly builds |
| Distribution | pip install torchft-nightly (prebuilt wheels) OR install from source (requires Rust + protobuf-compiler + maturin, only because of the Lighthouse/process-group Rust extension, not the algorithm code) |
| Python / torch | requires-python = ">=3.8"; torch>=2.7 per pyproject.toml |

(2) Exact API / extension point

The integration target is torchft/local_sgd.py. Two relevant classes:

```python
# Public class — drop-in context manager
class DiLoCo:
    def __init__(
        self,
        manager: Manager,                                  # we mock this
        model_fragments: List[nn.Module],                  # [model] for vanilla DiLoCo
        inner_optimizer: optim.Optimizer,
        outer_optimizer: optim.Optimizer | list[optim.Optimizer],
        sync_every: int,                                   # N inner steps
        backup_device: Optional[torch.device] = None,
        pin_memory: bool = True,
        use_bucketization: bool = False,
        bucket_cap_mb: Optional[int] = None,
        should_quantize: bool = False,
        fragment_sync_delay: int = 0,                      # τ in Streaming DiLoCo paper
        fragment_update_alpha: float = 0.0,
    ) -> None: ...
```

The pseudo-gradient is computed in _StreamingDiLoCoFragment._save_grads() (torchft/local_sgd.py line 324):

```python
def _save_grads(self) -> None:
    """Saves pseudo-gradients of the parameters"""
    with torch.no_grad():
        for name, p in self._model_fragment.named_parameters():
            local_param = p.to_local() if isinstance(p, DTensor) else p
            pseudogradient = self.original_parameters[name].to(p.device) - local_param
            self._grads[name] = pseudogradient
```

Note the sign: original − local, i.e. θ_initial − θ_local. When this is copied into p.grad via _set_grads, and p holds θ_initial when the outer step runs, a plain SGD step p ← p − lr · grad becomes p ← θ_initial − lr · (θ_initial − θ_local): a step toward θ_local (landing exactly on it when lr = 1). Our spec defines δ = θ_local − θ_initial, the negation of torchft's quantity. Either convention works as long as the value that reaches p.grad is θ_initial − θ_local: torchft pairs its convention with a positive outer_lr (e.g. 0.7) and a minimizing optimizer (SGD subtracts the grad), so the math nets out, whereas with our δ we must negate before writing p.grad. Be careful when unit-testing the sign in Spike 008.
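A quick numeric check of the sign argument, using toy values (illustrative assumptions, not torchft output) and plain SGD without momentum so that only the sign matters. The torch.optim.SGD call stands in for what _set_grads() followed by the outer step does to one parameter; it is a sketch of the arithmetic, not torchft's code.

```python
import torch

# Toy values (illustrative assumptions): one 2-element parameter.
theta_initial = torch.tensor([1.0, 2.0])
theta_local = torch.tensor([0.0, 4.0])    # where the inner loop ended up
outer_lr = 0.5                             # plain SGD, no momentum, to isolate the sign

# torchft convention: pseudogradient = θ_initial − θ_local, written into p.grad as-is.
pg_torchft = theta_initial - theta_local
p = torch.nn.Parameter(theta_initial.clone())   # p holds θ_initial at outer-step time
outer = torch.optim.SGD([p], lr=outer_lr)
p.grad = pg_torchft.clone()                     # stands in for _set_grads()
outer.step()                                    # p ← θ_initial − lr · (θ_initial − θ_local)

# Our spec's convention: δ = θ_local − θ_initial, which must be negated before it
# reaches a minimizing optimizer.
delta = theta_local - theta_initial
theta_next_ours = theta_initial - outer_lr * (-delta)

# The bug to guard against: writing δ into p.grad without negating it.
theta_wrong = theta_initial - outer_lr * delta

print(p.detach())        # tensor([0.5000, 3.0000]): halfway toward θ_local
print(theta_next_ours)   # same values as above
print(theta_wrong)       # tensor([1.5000, 1.0000]): moves away from θ_local
```

Both correct paths land halfway between θ_initial and θ_local because lr = 0.5; the un-negated δ moves the parameter away from where the replica trained.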

The outer Nesterov step is in _StreamingDiLoCoFragment.perform_sync() (line 423):

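```python
# excerpt from _StreamingDiLoCoFragment.perform_sync()
if should_commit:
    self._set_grads()                  # write pseudogradient into p.grad
    self._outer_optimizer.step()       # Nesterov SGD step (user-provided)
    self.save_parameters()             # snapshot post-step params as next round's θ_initial
    self._merge_parameters()           # merge local and global params (fragment_update_alpha)
self._outer_optimizer.zero_grad()
```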

The Nesterov-ness lives in the user-provided outer optimizer, e.g.:

```python
outer_optimizer = torch.optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)
```

This matches the DiLoCo paper's choice of outer optimizer (Douillard §3 specifies Nesterov momentum for the outer step).
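To pin down what the Nesterov outer step does numerically, here is a sketch that hand-rolls PyTorch's SGD(nesterov=True) update rule (dampening=0 and weight_decay=0, the defaults) for two outer rounds and checks it against torch.optim.SGD. The pseudogradient values are arbitrary toy inputs; this is the kind of hand-computed expectation a Spike 008 direction test could assert against.

```python
import torch

lr, mu = 0.7, 0.9                          # outer hyperparameters used in this document
p = torch.nn.Parameter(torch.zeros(1))
outer = torch.optim.SGD([p], lr=lr, momentum=mu, nesterov=True)

pseudograds = [torch.tensor([1.0]), torch.tensor([1.0])]   # toy values, one per outer round
buf = None
expected = torch.zeros(1)
for g in pseudograds:
    # PyTorch's momentum buffer: the first step copies g, later steps compute mu * buf + g.
    buf = g.clone() if buf is None else mu * buf + g
    # Nesterov look-ahead: the applied direction is g + mu * buf.
    expected = expected - lr * (g + mu * buf)

    p.grad = g.clone()                     # what _set_grads() would write
    outer.step()

# Round 1 moves p by −lr·(1+mu)·g = −1.33; round 2 by −0.7·(1 + 0.9·1.9) = −1.897.
print(expected, p.detach())
```

Note that the first outer step is already amplified by (1 + μ) = 1.9 relative to plain SGD, which matters when choosing tolerances for a hand-computed expected value.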

The cross-replica all-reduce happens in _average_grads() (called from prepare_sync) via self._manager.allreduce(...) — which is the seam we mock for single-process tests.
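Since the mocked manager.allreduce hands every replica the mean across replicas, one outer round can be reproduced in plain PyTorch without torchft. The helper below is a hypothetical test utility (reference_outer_round is our name, not a torchft API) that computes the expected parameters after a single round with a freshly constructed outer optimizer. With lr = 1 and no momentum, averaging pseudogradients reduces to plain parameter averaging, because mean(θ_initial − θ_i) = θ_initial − mean(θ_i).

```python
import torch


def reference_outer_round(theta_initial, local_thetas, lr=0.7, momentum=0.9):
    """Hypothetical test helper: expected parameters after ONE DiLoCo outer round,
    using a freshly constructed outer SGD optimizer (no momentum history yet)."""
    # Per-replica pseudogradients in torchft's sign convention: θ_initial − θ_local.
    pseudograds = [theta_initial - t for t in local_thetas]
    # What the cross-replica all-reduce (mean) hands back to every replica.
    avg_pg = torch.stack(pseudograds).mean(dim=0)

    p = torch.nn.Parameter(theta_initial.clone())
    # torch rejects nesterov=True with zero momentum, so only enable it when momentum > 0.
    outer = torch.optim.SGD([p], lr=lr, momentum=momentum, nesterov=momentum > 0)
    p.grad = avg_pg
    outer.step()
    return p.detach()


theta0 = torch.tensor([1.0, -1.0])
replicas = [torch.tensor([0.0, 0.0]), torch.tensor([2.0, 1.0])]

# lr = 1, no momentum: the outer step lands exactly on the average of the replicas.
print(reference_outer_round(theta0, replicas, lr=1.0, momentum=0.0))   # tensor([1.0000, 0.5000])
# Default Nesterov settings: a first step of −lr·(1+momentum) times the averaged pseudogradient.
print(reference_outer_round(theta0, replicas))
```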

(3) torch.distributed dependency for testing?

No, not for unit tests. The Manager is mockable. From torchft/local_sgd_test.py:

```python
from unittest.mock import create_autospec, MagicMock
from torchft.manager import Manager
from torchft.work import _DummyWork

def create_manager() -> MagicMock:
    manager = create_autospec(Manager)
    manager.errored.return_value = None
    def mock_allreduce(tensor: torch.Tensor, should_quantize: bool = False):
        return _DummyWork(tensor)        # returns the same tensor unchanged
    manager.allreduce.side_effect = mock_allreduce
    return manager
```

This bypasses NCCL/Gloo entirely. _DummyWork just wraps the tensor and returns it as the "all-reduced" result, so a single-process test with world_size=1 works directly. A 2-replica test runs two DiLoCo instances with two model copies in the same process and a mock_allreduce that averages the two replicas' tensors before returning. (Their test_bucketization_correctness does exactly this.) Because the sync fires synchronously inside each replica's inner_optimizer.step(), the two replicas' allreduce calls must be able to rendezvous, for example by running each replica in its own thread.

For real distributed runs torchft uses Gloo or NCCL via torchft.process_group (reconfigurable PGs that wrap torch.distributed). We do not need this for Spike 008.

(4) Library, research codebase, or paper-companion?

Library. Strong evidence:

  • Proper Python package layout (torchft/__init__.py, modules per concern).
  • Real unit tests (*_test.py per module) — not "run this script" demos.
  • BSD-3-Clause LICENSE (vs. diloco_simple having none, signaling "personal demo").
  • Nightly PyPI distribution (torchft-nightly) with prebuilt wheels.
  • Documentation site at https://pytorch.org/torchft.
  • meta-pytorch org — Meta-internally maintained; lives next to torchtitan.
  • README explicitly: "torchft is designed to provide the primitives required to implement fault tolerance in any application/train script" — i.e. a building block.

Only friction: installing from source needs Rust (pyo3 + maturin) and protobuf-compiler. This is for the Rust Lighthouse/process-group extension which we do not need for Spike 008's mock-based tests. Two clean options:

  • (a) pip install torchft-nightly — uses prebuilt wheel, no Rust toolchain needed.
  • (b) Vendor torchft/local_sgd.py + the few helpers (work.py::_DummyWork, type stubs for Manager) into our repo under BSD-3 attribution. ~700 LOC total.

(5) Minimum viable test pattern for Spike 008

Goal: 2 replicas × 4 inner steps × 2 outer rounds on a tiny model, single-process, no NCCL.

```python
# spikes/008-diloco-outer-loop/tests/test_diloco_two_replicas.py
"""
Spike 008: prove the DiLoCo outer-loop math is correct under our framework.
Runs entirely in a single process, no torch.distributed required.

torchft fires the sync (and therefore manager.allreduce) synchronously inside
inner_optimizer.step(), so each replica trains in its own thread and the mocked
allreduce uses a barrier to rendezvous, like a real 2-rank process group.
"""
import copy
import threading
from unittest.mock import create_autospec

import torch
import torch.nn as nn
import torch.optim as optim

from torchft.local_sgd import DiLoCo
from torchft.manager import Manager
from torchft.work import _DummyWork

WORLD_SIZE = 2
SYNC_EVERY = 4     # 4 inner steps per outer round
OUTER_ROUNDS = 2


class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

    def forward(self, x):
        return self.net(x)


class _ThreadedMeanAllReduce:
    """In-process stand-in for a process group. Every rank issues its allreduce calls
    in the same order (same parameter order), so call k of rank 0 pairs with call k of rank 1."""

    def __init__(self, world_size):
        self._slots = [None] * world_size
        self._barrier = threading.Barrier(world_size, timeout=60)

    def __call__(self, rank, tensor):
        self._slots[rank] = tensor.detach().clone()
        self._barrier.wait()                        # every rank has deposited its tensor
        mean = torch.stack(self._slots).mean(dim=0)
        self._barrier.wait()                        # every rank has read; slots may be reused
        tensor.copy_(mean)


def _make_manager(rank, fake_allreduce):
    mgr = create_autospec(Manager)
    mgr._use_async_quorum = False
    mgr.errored.return_value = None
    mgr.should_commit.return_value = True
    mgr.current_step.return_value = 0

    def allreduce(tensor, should_quantize=False):
        fake_allreduce(rank, tensor)                # real cross-replica mean, in place
        return _DummyWork(tensor)                   # already-completed work

    mgr.allreduce.side_effect = allreduce
    return mgr


def _train_replica(rank, model, mgr, inner, outer, errors):
    gen = torch.Generator().manual_seed(rank)       # replicas see DIFFERENT data: the point of DiLoCo
    try:
        with DiLoCo(mgr, [model], inner, outer, sync_every=SYNC_EVERY):
            for _ in range(OUTER_ROUNDS * SYNC_EVERY):
                x = torch.randn(8, 4, generator=gen)
                y = torch.randn(8, 2, generator=gen)
                inner.zero_grad()
                ((model(x) - y) ** 2).mean().backward()
                inner.step()    # post-step hook syncs when the local step hits SYNC_EVERY
    except Exception as exc:    # surface thread failures to the test
        errors.append(exc)


def test_diloco_two_replicas_four_inner_two_outer():
    torch.manual_seed(0)
    model_a = TinyMLP()
    models = [model_a, copy.deepcopy(model_a)]      # identical init = same θ_initial
    theta_initial = {n: p.detach().clone() for n, p in model_a.named_parameters()}

    # Inner AdamW and outer Nesterov SGD, one of each per replica, same hyperparams.
    inners = [optim.AdamW(m.parameters(), lr=1e-3) for m in models]
    outers = [optim.SGD(m.parameters(), lr=0.7, momentum=0.9, nesterov=True) for m in models]
    fake_allreduce = _ThreadedMeanAllReduce(WORLD_SIZE)
    mgrs = [_make_manager(r, fake_allreduce) for r in range(WORLD_SIZE)]

    errors = []
    threads = [
        threading.Thread(
            target=_train_replica,
            args=(r, models[r], mgrs[r], inners[r], outers[r], errors),
        )
        for r in range(WORLD_SIZE)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    assert not errors, f"replica thread failed: {errors[0]!r}"

    # 1. Both replicas hold IDENTICAL parameters (same averaged pseudogradient, same outer step).
    for (na, pa), (_, pb) in zip(models[0].named_parameters(), models[1].named_parameters()):
        torch.testing.assert_close(pa, pb, msg=f"Replicas diverged at {na}")

    # 2. Parameters changed from θ_initial (outer optimizer actually stepped).
    assert any(
        not torch.equal(p, theta_initial[n]) for n, p in models[0].named_parameters()
    ), "outer optimizer did not move the parameters"

    # 3. The outer optimizer holds Nesterov momentum state for every parameter
    #    (proves the SGD(nesterov=True) actually ran).
    n_params = len(list(models[0].parameters()))
    assert len(outers[0].state_dict()["state"]) == n_params

    # 4. Sync fired once per outer round per replica.
    for mgr in mgrs:
        assert mgr.start_quorum.call_count == OUTER_ROUNDS
```

Why this works:

  • DiLoCo registers a post-step hook on inner_optimizer (see __enter__). The hook increments _local_step and triggers prepare_sync / perform_sync on every sync_every boundary, so the sync is fully automatic and the test only calls inner.step().
  • That sync runs synchronously inside each replica's inner.step(), so the replicas run in separate threads and _ThreadedMeanAllReduce uses a threading.Barrier to make their allreduce calls rendezvous, as a real 2-rank all-reduce would. A shared list that pairs consecutive calls does not work: run sequentially, replica A issues all of its allreduce calls inside its own inner.step() before replica B steps, so A's calls would be paired with each other.
  • _DummyWork.wait() is a no-op. _average_grads calls manager.allreduce(...), which the mock turns into a real cross-replica mean.
  • manager.should_commit.return_value = True lets the outer optimizer fire on each outer round; setting it to False lets us also test rollback semantics.
  • Everything stays in a single process, so pytest runs it directly. Put it in a new spikes/008-diloco-outer-loop/tests/ directory, following the layout of spikes/005-integrated-trainer-skeleton/tests/.
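The automatic sync rests on PyTorch's optimizer step hooks (Optimizer.register_step_post_hook). The sketch below isolates the counting idea; attach_sync_every is a hypothetical helper that mirrors the concept, not torchft's implementation, which additionally handles quorum, fragments, and commit logic.

```python
import torch


def attach_sync_every(optimizer, sync_every, on_sync):
    """Hypothetical helper: call on_sync(step) after every `sync_every` optimizer steps."""
    state = {"local_step": 0}

    def _post_hook(opt, args, kwargs):
        # Runs after every optimizer.step(); PyTorch passes the step's args/kwargs.
        state["local_step"] += 1
        if state["local_step"] % sync_every == 0:
            on_sync(state["local_step"])

    # Returns a RemovableHandle; call .remove() to detach the hook.
    return optimizer.register_step_post_hook(_post_hook)


model = torch.nn.Linear(2, 1)
inner = torch.optim.SGD(model.parameters(), lr=0.1)
fired = []
handle = attach_sync_every(inner, sync_every=4, on_sync=fired.append)
for _ in range(8):
    inner.zero_grad()
    model(torch.ones(1, 2)).sum().backward()
    inner.step()
handle.remove()
print(fired)   # [4, 8]
```

Keeping the trigger in a post-step hook is what lets training code stay unchanged: the loop only ever calls step(), and the outer round happens as a side effect on every sync_every boundary.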

Install for this spike: pip install torchft-nightly in the eidolon venv. If the nightly wheel proves brittle, fall back to vendoring local_sgd.py + work.py + a minimal manager.py stub (≈800 LOC) into framework/diloco/_vendored/ with BSD-3 attribution.


Risks & Mitigations

| Risk | Likelihood | Mitigation |
|---|---|---|
| torchft-nightly wheel breaks against torch 2.x | Med | Pin to a specific nightly version, or vendor local_sgd.py directly under BSD-3. |
| Importing torchft.manager.Manager pulls in the Rust extension at import time | Low | local_sgd.py imports it as from torchft.manager import Manager, and our test uses Manager only as a type and as the spec for create_autospec; the mock replaces every instance. That does not avoid import-time side effects, though: importing the module still executes it, so if it loads the compiled extension at top level, we rely on the prebuilt torchft-nightly wheel or vendor with a Manager stub. |
| Sign convention of the pseudogradient makes our outer optimizer move the wrong way | Med | Assertion 2 in the test pattern above only checks that params moved from θ_initial, not in which direction. A second test should compare the update against a hand-computed expected value. |
| fragment_sync_delay > 0 (true Streaming) requires CUDA streams | Med | Spike 008 starts with fragment_sync_delay=0 (= vanilla DiLoCo). The Streaming variant is deferred to Spike 009 once the basic loop works. |
| Requires torch>=2.7 per pyproject | Low | Framework is already on torch 2.x; check the exact pin. If it is below 2.7, we vendor. |

Decision (for ADR-003)

Adopt torchft.local_sgd.DiLoCo as the reference DiLoCo / Streaming DiLoCo implementation. Integrate via pip install torchft-nightly for Spike 008. If brittleness emerges, vendor local_sgd.py (BSD-3) into framework/diloco/_vendored/.

For the framework's outer-loop optimizer abstraction (the actual ADR-003 question): mirror torchft's DiLoCo(manager, [model_fragments], inner_opt, outer_opt, sync_every) constructor shape so that swapping our wrapper for the upstream class is a one-line change. Compute pseudogradient as θ_local − θ_initial (our convention) and negate when handing to the outer optimizer, OR follow torchft's θ_initial − θ_local convention end-to-end. Pick one and document it loudly.