File size: 5,202 Bytes
b266c31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd0c358
 
 
 
 
 
 
 
 
 
b266c31
bd0c358
 
 
 
 
b266c31
 
bd0c358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b266c31
bd0c358
 
 
 
 
b266c31
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""Replica entrypoint — what each serverless replica runs.

This is the script invoked by `LocalProcessExecutor`, `ModalExecutor`,
`HFJobsExecutor`, etc. It learns its rank from the `REPLICA_RANK` env
var, sets up `ObjectStoreAllReduce` against the shared rendezvous URI,
wraps it in a `MockManager`, and hands it off to the user's training
function.

Usage from an executor:

    >>> executor.launch_replicas(
    ...     n_replicas=4,
    ...     entrypoint="composer_replication.diloco.serverless.replica_entrypoint",
    ...     entrypoint_args={
    ...         "rendezvous_uri": "/tmp/run42/",
    ...         "world_size": 4,
    ...         "trainer_module": "my_project.trainer",
    ...         "trainer_fn": "train",
    ...         "trainer_kwargs": {"model_name": "Qwen/Qwen2.5-0.5B"},
    ...     },
    ... )

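The entrypoint expects:
- `REPLICA_RANK` env var set to the rank (0..world_size-1)
- `rendezvous_uri`: fsspec URI for object-store rendezvous
- `world_size`: total replicas
- `trainer_module`, `trainer_fn`: importable path to the user's train fn
- `trainer_kwargs`: dict passed to the user's train fn, plus the injected
  kwargs `manager` (the `MockManager`), `rank`, and `world_size`; the
  injected values override any same-named keys in `trainer_kwargs`

When run as a script, each value may be supplied either as an argv flag
(`--rendezvous`, `--world-size`, `--trainer-module`, `--trainer-fn`,
`--trainer-kwargs-json`) or as the matching environment variable
(`RENDEZVOUS_URI`, `WORLD_SIZE`, `TRAINER_MODULE`, `TRAINER_FN`,
`TRAINER_KWARGS_JSON`); an argv flag takes precedence over its env var.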
"""
from __future__ import annotations

import importlib
import os
from typing import Any


def main(
    rendezvous_uri: str,
    world_size: int,
    trainer_module: str,
    trainer_fn: str = "train",
    trainer_kwargs: dict[str, Any] | None = None,
) -> Any:
    """Entrypoint executed inside each replica.

    Reads this replica's rank from the ``REPLICA_RANK`` env var, resolves the
    user's train function, builds an ``ObjectStoreAllReduce`` against
    ``rendezvous_uri`` wrapped in a ``MockManager``, and calls the train
    function with ``trainer_kwargs`` plus the injected ``manager``, ``rank``
    and ``world_size`` kwargs. Injected values override same-named keys in
    ``trainer_kwargs``.

    Args:
        rendezvous_uri: fsspec URI (or local path) for the rendezvous
        world_size: total replicas
        trainer_module: importable Python module containing the user's
            train function
        trainer_fn: name of the function to call (default "train")
        trainer_kwargs: kwargs passed to the train function (not mutated)

    Returns:
        Whatever the train function returns.

    Raises:
        RuntimeError: ``REPLICA_RANK`` is not set.
        ValueError: ``REPLICA_RANK`` is not an integer, or is not in
            ``[0, world_size)``.
        ImportError: ``trainer_module`` cannot be imported.
        AttributeError: ``trainer_module`` has no attribute ``trainer_fn``.
        TypeError: ``trainer_module.trainer_fn`` is not callable.
    """
    # Validate cheap, replica-local configuration first so a misconfigured
    # replica fails fast, before the allreduce module is imported.
    rank_str = os.environ.get("REPLICA_RANK")
    if rank_str is None:
        raise RuntimeError(
            "REPLICA_RANK env var not set. The serverless executor "
            "should set this for each replica."
        )
    try:
        rank = int(rank_str)
    except ValueError as err:
        raise ValueError(f"REPLICA_RANK={rank_str!r} is not an integer") from err

    if not (0 <= rank < world_size):
        raise ValueError(f"REPLICA_RANK={rank} not in [0, {world_size})")

    # Resolve the user's train function before constructing the allreduce
    # store, so a typo in trainer_module / trainer_fn is reported first.
    mod = importlib.import_module(trainer_module)
    try:
        fn = getattr(mod, trainer_fn)
    except AttributeError:
        raise AttributeError(
            f"trainer module {trainer_module!r} has no attribute {trainer_fn!r}"
        ) from None
    if not callable(fn):
        raise TypeError(
            f"{trainer_module}.{trainer_fn} is not callable "
            f"(got {type(fn).__name__})"
        )

    # Imported lazily: only needed once the replica is known to be valid.
    from composer_replication.diloco.serverless.allreduce import (
        MockManager,
        ObjectStoreAllReduce,
    )

    store = ObjectStoreAllReduce(
        uri=rendezvous_uri,
        rank=rank,
        world_size=world_size,
    )
    manager = MockManager(store)

    # Copy so the caller's dict is never mutated; the injected keys are
    # assigned last, so they win over user-supplied keys of the same name.
    kwargs = dict(trainer_kwargs or {})
    kwargs["manager"] = manager
    kwargs["rank"] = rank
    kwargs["world_size"] = world_size
    return fn(**kwargs)


if __name__ == "__main__":
    import argparse
    import json

    # Dual input contract (both backends supported):
    #   * argv  — SageMakerExecutor / LocalProcessExecutor pass the run config as
    #     `--rendezvous/--world-size/--trainer-module` ContainerArguments.
    #   * env   — EKSExecutor (and any backend that prefers a pure-env contract,
    #     since k8s Indexed Jobs already inject REPLICA_RANK via the downward API)
    #     pass the SAME values as RENDEZVOUS_URI / WORLD_SIZE / TRAINER_MODULE
    #     env vars. The argv flags are therefore NOT `required=True`: when absent
    #     we fall back to the env vars, and only error if NEITHER source supplies
    #     a mandatory field. This is the R3 fix — previously the argparse block
    #     hard-required argv, so an EKS pod (env-only) crashed at arg-parsing.
    parser = argparse.ArgumentParser(
        description="Run one serverless replica (argv flags override env vars)."
    )
    parser.add_argument(
        "--rendezvous", default=None,
        help="rendezvous URI (env fallback: RENDEZVOUS_URI)",
    )
    parser.add_argument(
        "--world-size", type=int, default=None,
        help="total number of replicas (env fallback: WORLD_SIZE)",
    )
    parser.add_argument(
        "--trainer-module", default=None,
        help="importable module with the train fn (env fallback: TRAINER_MODULE)",
    )
    parser.add_argument(
        "--trainer-fn", default=None,
        help="train function name, default 'train' (env fallback: TRAINER_FN)",
    )
    parser.add_argument(
        "--trainer-kwargs-json", default=None,
        help="JSON object of train fn kwargs (env fallback: TRAINER_KWARGS_JSON)",
    )
    args = parser.parse_args()

    def _resolve(arg_val, flag, env_key, *, required, cast=lambda x: x):
        """Return the argv value if given, else the cast env value, else None.

        An env var that is empty or whitespace-only counts as unset: templated
        env blocks often render a missing value as "". Raises SystemExit with
        a readable message when a required value is absent from both sources
        or when the env value cannot be cast.
        """
        if arg_val is not None:
            return arg_val
        env_val = os.environ.get(env_key, "").strip()
        if env_val:
            try:
                return cast(env_val)
            except ValueError as err:
                raise SystemExit(
                    f"replica_entrypoint: invalid {env_key}={env_val!r}: {err}"
                ) from err
        if required:
            raise SystemExit(
                f"replica_entrypoint: missing '{env_key}' — supply it via the "
                f"{flag} argv flag or the {env_key} environment variable "
                f"(EKSExecutor uses env; SageMaker/Local use argv)."
            )
        return None

    rendezvous = _resolve(
        args.rendezvous, "--rendezvous", "RENDEZVOUS_URI", required=True
    )
    world_size = _resolve(
        args.world_size, "--world-size", "WORLD_SIZE", required=True, cast=int
    )
    trainer_module = _resolve(
        args.trainer_module, "--trainer-module", "TRAINER_MODULE", required=True
    )
    trainer_fn = _resolve(
        args.trainer_fn, "--trainer-fn", "TRAINER_FN", required=False
    ) or "train"
    kwargs_json = _resolve(
        args.trainer_kwargs_json,
        "--trainer-kwargs-json",
        "TRAINER_KWARGS_JSON",
        required=False,
    ) or "{}"

    try:
        trainer_kwargs = json.loads(kwargs_json)
    except json.JSONDecodeError as err:
        raise SystemExit(
            f"replica_entrypoint: --trainer-kwargs-json / TRAINER_KWARGS_JSON "
            f"is not valid JSON: {err}"
        ) from err
    # Falsy values (null, [], {}) are passed through: main() treats them as
    # "no kwargs". Anything else must be a JSON object to be splatted.
    if trainer_kwargs and not isinstance(trainer_kwargs, dict):
        raise SystemExit(
            f"replica_entrypoint: trainer kwargs must be a JSON object, got "
            f"{type(trainer_kwargs).__name__}"
        )

    main(
        rendezvous_uri=rendezvous,
        world_size=world_size,
        trainer_module=trainer_module,
        trainer_fn=trainer_fn,
        trainer_kwargs=trainer_kwargs,
    )