Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 5,202 Bytes
b266c31 bd0c358 b266c31 bd0c358 b266c31 bd0c358 b266c31 bd0c358 b266c31 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 | """Replica entrypoint — what each serverless replica runs.
This is the script invoked by `LocalProcessExecutor`, `ModalExecutor`,
`HFJobsExecutor`, etc. It learns its rank from the `REPLICA_RANK` env
var, sets up `ObjectStoreAllReduce` against the shared rendezvous URI,
wraps it in a `MockManager`, and hands it off to the user's training
function.
Usage from an executor:
>>> executor.launch_replicas(
... n_replicas=4,
... entrypoint="composer_replication.diloco.serverless.replica_entrypoint",
... entrypoint_args={
... "rendezvous_uri": "/tmp/run42/",
... "world_size": 4,
... "trainer_module": "my_project.trainer",
... "trainer_fn": "train",
... "trainer_kwargs": {"model_name": "Qwen/Qwen2.5-0.5B"},
... },
... )
The entrypoint expects:
- `REPLICA_RANK` env var set to the rank (0..world_size-1)
- `rendezvous_uri`: fsspec URI for object-store rendezvous
- `world_size`: total replicas
- `trainer_module`, `trainer_fn`: importable path to the user's train fn
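- `trainer_kwargs`: dict passed to the user's train fn, plus injected
`manager` (the `MockManager`), `rank`, and `world_size` kwargs (these
three keys are always overwritten)
When run as a script, `rendezvous_uri` / `world_size` / `trainer_module`
(and optionally `trainer_fn` / `trainer_kwargs`) are read from argv flags,
falling back to the RENDEZVOUS_URI / WORLD_SIZE / TRAINER_MODULE /
TRAINER_FN / TRAINER_KWARGS_JSON environment variables.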
"""
from __future__ import annotations
import importlib
import os
from typing import Any
def main(
    rendezvous_uri: str,
    world_size: int,
    trainer_module: str,
    trainer_fn: str = "train",
    trainer_kwargs: dict[str, Any] | None = None,
) -> Any:
    """Entrypoint executed inside each replica.

    Reads this replica's rank from the ``REPLICA_RANK`` environment variable,
    builds an ``ObjectStoreAllReduce`` against ``rendezvous_uri`` wrapped in a
    ``MockManager``, then imports ``trainer_module`` and calls its
    ``trainer_fn`` with the supplied kwargs plus the injected ones.

    Args:
        rendezvous_uri: fsspec URI (or local path) for the rendezvous.
        world_size: total number of replicas; must be >= 1.
        trainer_module: importable Python module containing the user's
            train function.
        trainer_fn: name of the function to call (default ``"train"``).
        trainer_kwargs: kwargs passed to the train function. The mapping is
            copied, not mutated. The keys ``manager``, ``rank`` and
            ``world_size`` are always overwritten with the injected values.

    Returns:
        Whatever the train function returns.

    Raises:
        RuntimeError: if ``REPLICA_RANK`` is not set.
        ValueError: if ``REPLICA_RANK`` is not an integer, if ``world_size``
            is less than 1, or if the rank is outside ``[0, world_size)``.
        ImportError: if ``trainer_module`` cannot be imported.
        AttributeError: if ``trainer_module`` has no attribute ``trainer_fn``.
    """
    # Validate the replica's identity first so a misconfigured launch fails
    # with a clear message before the project package is imported.
    rank_str = os.environ.get("REPLICA_RANK")
    if rank_str is None:
        raise RuntimeError(
            "REPLICA_RANK env var not set. The serverless executor "
            "should set this for each replica."
        )
    try:
        rank = int(rank_str)
    except ValueError as err:
        raise ValueError(f"REPLICA_RANK={rank_str!r} is not an integer") from err
    if world_size < 1:
        # Without this check a zero/negative world size surfaces as the
        # confusing "REPLICA_RANK=0 not in [0, 0)".
        raise ValueError(f"world_size must be >= 1, got {world_size}")
    if not (0 <= rank < world_size):
        raise ValueError(f"REPLICA_RANK={rank} not in [0, {world_size})")

    from composer_replication.diloco.serverless.allreduce import (
        MockManager,
        ObjectStoreAllReduce,
    )

    store = ObjectStoreAllReduce(
        uri=rendezvous_uri,
        rank=rank,
        world_size=world_size,
    )
    manager = MockManager(store)

    mod = importlib.import_module(trainer_module)
    fn = getattr(mod, trainer_fn)

    # Copy so the caller's dict is left untouched, then inject the
    # replica-specific values (these win over any user-supplied keys).
    kwargs = dict(trainer_kwargs or {})
    kwargs["manager"] = manager
    kwargs["rank"] = rank
    kwargs["world_size"] = world_size
    return fn(**kwargs)
if __name__ == "__main__":
    import argparse
    import json

    # Dual input contract (both backends supported):
    #   * argv -- SageMakerExecutor / LocalProcessExecutor pass the run config
    #     as `--rendezvous/--world-size/--trainer-module` ContainerArguments.
    #   * env  -- EKSExecutor (and any backend that prefers a pure-env
    #     contract, since k8s Indexed Jobs already inject REPLICA_RANK via the
    #     downward API) pass the SAME values as RENDEZVOUS_URI / WORLD_SIZE /
    #     TRAINER_MODULE env vars.
    # The argv flags are therefore NOT `required=True`: when absent we fall
    # back to the env vars, and only error if NEITHER source supplies a
    # mandatory field. This is the R3 fix -- previously the argparse block
    # hard-required argv, so an EKS pod (env-only) crashed at arg-parsing.
    parser = argparse.ArgumentParser()
    parser.add_argument("--rendezvous", default=None)
    parser.add_argument("--world-size", type=int, default=None)
    parser.add_argument("--trainer-module", default=None)
    parser.add_argument("--trainer-fn", default=None)
    parser.add_argument("--trainer-kwargs-json", default=None)
    args = parser.parse_args()

    def _is_unset(value):
        """True for None and for blank strings (e.g. an env var templated to "")."""
        return value is None or (isinstance(value, str) and not value.strip())

    def _resolve(arg_val, env_key, *, required, flag, cast=lambda x: x):
        """Return the argv value, else the (cast) env value, else None.

        Blank strings count as unset in both sources. Exits with a readable
        message when a required field is missing or the env value fails
        ``cast``. Argv values are returned as-is (argparse already typed them).
        """
        if not _is_unset(arg_val):
            return arg_val
        env_val = os.environ.get(env_key)
        if not _is_unset(env_val):
            try:
                return cast(env_val)
            except ValueError:
                raise SystemExit(
                    f"replica_entrypoint: invalid value {env_val!r} for the "
                    f"{env_key} environment variable."
                ) from None
        if required:
            raise SystemExit(
                f"replica_entrypoint: missing '{env_key}' — supply it via the "
                f"{flag} argv flag or the {env_key} environment variable "
                f"(EKSExecutor uses env; SageMaker/Local use argv)."
            )
        return None

    rendezvous = _resolve(
        args.rendezvous, "RENDEZVOUS_URI", required=True, flag="--rendezvous"
    )
    world_size = _resolve(
        args.world_size, "WORLD_SIZE", required=True, flag="--world-size", cast=int
    )
    trainer_module = _resolve(
        args.trainer_module, "TRAINER_MODULE", required=True, flag="--trainer-module"
    )
    trainer_fn = _resolve(
        args.trainer_fn, "TRAINER_FN", required=False, flag="--trainer-fn"
    ) or "train"
    kwargs_json = _resolve(
        args.trainer_kwargs_json,
        "TRAINER_KWARGS_JSON",
        required=False,
        flag="--trainer-kwargs-json",
    ) or "{}"

    # Parse and validate here so a bad payload produces a clear message
    # rather than a traceback from json or from dict() inside main().
    try:
        trainer_kwargs = json.loads(kwargs_json)
    except json.JSONDecodeError as err:
        raise SystemExit(
            f"replica_entrypoint: trainer kwargs are not valid JSON: {err}"
        ) from None
    # `null` is still accepted: main() treats None as "no kwargs".
    if trainer_kwargs is not None and not isinstance(trainer_kwargs, dict):
        raise SystemExit(
            "replica_entrypoint: trainer kwargs JSON must be an object, got "
            f"{type(trainer_kwargs).__name__}"
        )

    main(
        rendezvous_uri=rendezvous,
        world_size=world_size,
        trainer_module=trainer_module,
        trainer_fn=trainer_fn,
        trainer_kwargs=trainer_kwargs,
    )
|