|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import sys |
|
|
import uuid |
|
|
from dataclasses import dataclass, field |
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
|
|
|
|
|
import torch.distributed.elastic.rendezvous.registry as rdzv_registry |
|
|
from torch.distributed.elastic import events, metrics |
|
|
from torch.distributed.elastic.agent.server.api import WorkerSpec |
|
|
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent |
|
|
from torch.distributed.elastic.multiprocessing import SignalException, Std |
|
|
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError |
|
|
from torch.distributed.elastic.rendezvous import RendezvousParameters |
|
|
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint |
|
|
from torch.distributed.elastic.utils.logging import get_logger |
|
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = ['LaunchConfig', 'elastic_launch', 'launch_agent']

# Module-level logger obtained from the torchelastic logging helper.
logger = get_logger()
|
|
|
|
|
|
|
|
@dataclass
class LaunchConfig:
    """
    Creates a rendezvous config.

    Args:
        min_nodes: Minimum amount of nodes that the user function will
                   be launched on. Elastic agent ensures that the user
                   function start only when the min_nodes amount enters
                   the rendezvous.
        max_nodes: Maximum amount of nodes that the user function
                   will be launched on.
        nproc_per_node: On each node the elastic agent will launch
                        this amount of workers that will execute user
                        defined function.
        rdzv_backend: rdzv_backend to use in the rendezvous (zeus-adapter, etcd).
                      Defaults to "etcd".
        rdzv_endpoint: The endpoint of the rdzv sync. storage.
        rdzv_configs: Key, value pair that specifies rendezvous specific configuration.
                      The mapping is shallow-copied on construction, so the
                      dict passed by the caller is never modified.
        rdzv_timeout: Legacy argument that specifies timeout for the rendezvous. It is going
                      to be removed in future versions, see the note below. The default
                      timeout is 900 seconds. ``-1`` (the default) means "not set".
        run_id: The unique run id of the job (if not passed a unique one will be
                deduced from run environment - flow workflow id in flow - or auto generated).
        role: User defined role of the worker (defaults to "default_role").
        max_restarts: The maximum amount of restarts that elastic agent will conduct
                      on workers before failure.
        monitor_interval: The interval in seconds that is used by the elastic_agent
                          as a period of monitoring workers.
        start_method: The method is used by the elastic agent to start the
                      workers (spawn, fork, forkserver).
        log_dir: base log directory where log files are written. If not set,
                 one is created in a tmp dir but NOT removed on exit.
        redirects: configuration to redirect stdout/stderr to log files.
                   Pass a single ``Std`` enum to redirect all workers,
                   or a mapping keyed by local_rank to selectively redirect.
        tee: configuration to "tee" stdout/stderr to console + log file.
        metrics_cfg: configuration to initialize metrics.

    .. note::
        `rdzv_timeout` is a legacy argument that will be removed in future.
        Set the timeout via `rdzv_configs['timeout']`

    """

    # Field order defines the positional order of the generated ``__init__``;
    # do not reorder.
    min_nodes: int
    max_nodes: int
    nproc_per_node: int
    run_id: str = ""
    role: str = "default_role"
    rdzv_endpoint: str = ""
    rdzv_backend: str = "etcd"
    rdzv_configs: Dict[str, Any] = field(default_factory=dict)
    rdzv_timeout: int = -1
    max_restarts: int = 3
    monitor_interval: float = 30
    start_method: str = "spawn"
    log_dir: Optional[str] = None
    redirects: Union[Std, Dict[int, Std]] = Std.NONE
    tee: Union[Std, Dict[int, Std]] = Std.NONE
    metrics_cfg: Dict[str, str] = field(default_factory=dict)

    def __post_init__(self):
        """Ensure ``rdzv_configs`` always carries a ``"timeout"`` entry."""
        default_timeout = 900
        # Work on a private copy: writing the timeout into the mapping the
        # caller handed in would leak into every other user of that dict.
        self.rdzv_configs = dict(self.rdzv_configs)
        if self.rdzv_timeout != -1:
            # An explicitly set legacy ``rdzv_timeout`` overrides any
            # ``"timeout"`` key already present in ``rdzv_configs``.
            self.rdzv_configs["timeout"] = self.rdzv_timeout
        elif "timeout" not in self.rdzv_configs:
            self.rdzv_configs["timeout"] = default_timeout
|
|
|
|
|
|
|
|
class elastic_launch:
    """
    Launches a torchelastic agent on the container that invoked the entrypoint.

    1. Pass the ``entrypoint`` arguments as non ``kwargs`` (e.g. no named parameters).
       ``entrypoint`` can be a function or a command.
    2. The return value is a map of each worker's output mapped
       by their respective global rank.

    Usage

    ::

        def worker_fn(foo):
            # ...

        def main():
            # entrypoint is a function.
            outputs = elastic_launch(LaunchConfig, worker_fn)(foo)
            # return rank 0's output
            return outputs[0]

            # entrypoint is a command and ``script.py`` is the python module.
            outputs = elastic_launch(LaunchConfig, "script.py")(args)
            outputs = elastic_launch(LaunchConfig, "python")("script.py")
    """

    def __init__(
        self,
        config: LaunchConfig,
        entrypoint: Union[Callable, str, None],
    ):
        # Only remember the inputs; nothing is launched until the instance
        # is called.
        self._entrypoint = entrypoint
        self._config = config

    def __call__(self, *args):
        # The positional arguments are forwarded to the entrypoint as a list.
        entrypoint_args = list(args)
        return launch_agent(self._config, self._entrypoint, entrypoint_args)
|
|
|
|
|
|
|
|
def _get_entrypoint_name(
    entrypoint: Union[Callable, str, None], args: List[Any]
) -> str:
    """Retrieve the entrypoint name with the rule:

    1. If entrypoint is callable, use ``entrypoint.__name__``. Callables that
       carry no ``__name__`` (for example ``functools.partial`` objects or
       instances defining ``__call__``) fall back to the name of their type.
    2. If entrypoint is a string, check its value:
        2.1 if entrypoint equals ``sys.executable`` (like "python"), use the
            first element of ``args`` whose string form is non-empty and does
            not start with a hyphen (for example, "-u" will be skipped).
        2.2 otherwise, use the ``entrypoint`` value.
    3. Otherwise, return an empty string.

    Args:
        entrypoint: function, command string, or ``None``.
        args: arguments that will be passed to the entrypoint.

    Returns:
        The derived name, or ``""`` when none can be determined.
    """
    if callable(entrypoint):
        name = getattr(entrypoint, "__name__", None)
        # Not every callable has ``__name__``; never fail just to build a label.
        return name if isinstance(name, str) else type(entrypoint).__name__

    if isinstance(entrypoint, str):
        if entrypoint != sys.executable:
            return entrypoint
        # ``startswith`` is safe on empty strings (``arg[0]`` is not), and
        # stringifying keeps non-string arguments from raising.
        candidates = (str(arg) for arg in args)
        return next(
            (text for text in candidates if text and not text.startswith("-")),
            "",
        )

    return ""
|
|
|
|
|
|
|
|
def _get_addr_and_port(
    rdzv_parameters: RendezvousParameters,
) -> Tuple[Optional[str], Optional[int]]:
    """Extract the master address and port for the ``static`` backend.

    Args:
        rdzv_parameters: rendezvous parameters; only ``backend`` and
            ``endpoint`` are read.

    Returns:
        ``(None, None)`` for any backend other than ``"static"``; otherwise
        the ``(address, port)`` pair parsed from the endpoint.

    Raises:
        ValueError: if the backend is ``"static"`` and the endpoint is blank
            or does not include a port.
    """
    uses_static_backend = rdzv_parameters.backend == "static"
    if not uses_static_backend:
        return (None, None)

    stripped = rdzv_parameters.endpoint.strip()
    if stripped == "":
        raise ValueError(
            "Endpoint is missing in endpoint. Try to add --master_addr and --master_port"
        )

    # -1 acts as a sentinel default so a missing port can be detected below.
    addr, port = parse_rendezvous_endpoint(stripped, default_port=-1)
    if port == -1:
        raise ValueError(
            f"port is missing in endpoint: {stripped}. Try to specify --master_port"
        )

    return (addr, port)
|
|
|
|
|
|
|
|
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    """Build a ``LocalElasticAgent`` from ``config`` and run it to completion.

    Args:
        config: launch configuration. If ``config.run_id`` is empty, a random
            one is generated and written back onto ``config``.
        entrypoint: function or command that the workers execute.
        args: arguments passed to ``entrypoint``.

    Returns:
        ``result.return_values`` of the agent run.

    Raises:
        ChildFailedError: if the agent run result reports failure.
        ValueError: propagated from ``_get_addr_and_port`` when the backend is
            ``"static"`` and the endpoint is blank or has no port.
        Exception: anything raised while the agent runs is re-raised after a
            failure event is recorded.
    """
    if not config.run_id:
        run_id = str(uuid.uuid4().int)
        logger.warning(f"config has no run_id, generated a random run_id: {run_id}")
        # NOTE: this mutates the caller's config object.
        config.run_id = run_id

    # Name used in the log banner below and in ChildFailedError.
    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    logger.info(
        f"Starting elastic_operator with launch configs:\n"
        f" entrypoint : {entrypoint_name}\n"
        f" min_nodes : {config.min_nodes}\n"
        f" max_nodes : {config.max_nodes}\n"
        f" nproc_per_node : {config.nproc_per_node}\n"
        f" run_id : {config.run_id}\n"
        f" rdzv_backend : {config.rdzv_backend}\n"
        f" rdzv_endpoint : {config.rdzv_endpoint}\n"
        f" rdzv_configs : {config.rdzv_configs}\n"
        f" max_restarts : {config.max_restarts}\n"
        f" monitor_interval : {config.monitor_interval}\n"
        f" log_dir : {config.log_dir}\n"
        f" metrics_cfg : {config.metrics_cfg}\n"
    )

    # Every entry of rdzv_configs (including "timeout") is forwarded as an
    # extra keyword argument.
    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        **config.rdzv_configs,
    )

    # (None, None) unless the backend is "static"; validated before the
    # rendezvous handler is created.
    master_addr, master_port = _get_addr_and_port(rdzv_parameters)

    spec = WorkerSpec(
        role=config.role,
        local_world_size=config.nproc_per_node,
        entrypoint=entrypoint,
        args=tuple(args),
        rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
        max_restarts=config.max_restarts,
        monitor_interval=config.monitor_interval,
        redirects=config.redirects,
        tee=config.tee,
        master_addr=master_addr,
        master_port=master_port,
    )

    agent = LocalElasticAgent(
        spec=spec, start_method=config.start_method, log_dir=config.log_dir
    )

    # Cleared only on the SignalException path; consulted in ``finally``.
    shutdown_rdzv = True
    try:
        metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))

        result = agent.run()

        # Recorded as soon as agent.run() returns, before the result is
        # inspected: the event is emitted even when workers failed.
        events.record(agent.get_event_succeeded())

        if result.is_failed():
            # Worker failures surface as ChildFailedError carrying the
            # per-worker failure details from the run result.
            raise ChildFailedError(
                name=entrypoint_name,
                failures=result.failures,
            )

        return result.return_values
    except ChildFailedError:
        # Re-raised as is so the generic handler below does not record an
        # agent-failed event for it.
        raise
    except SignalException:
        # The agent exited because of a signal: record the failure but leave
        # the rendezvous handler open.
        # NOTE(review): presumably so the rendezvous stays usable for the
        # remaining nodes - confirm.
        shutdown_rdzv = False
        events.record(agent.get_event_failed())
        raise
    except Exception:
        events.record(agent.get_event_failed())
        raise
    finally:
        if shutdown_rdzv:
            spec.rdzv_handler.shutdown()
|
|
|