import json
import os
import shutil
import signal
import socket
import tempfile
import uuid
from typing import Any, Dict, Optional, Tuple

import torch.distributed.elastic.timer as timer
from torch.distributed.elastic import events
from torch.distributed.elastic.agent.server.api import (
    RunResult,
    SimpleElasticAgent,
    WorkerGroup,
    WorkerSpec,
    WorkerState,
)
from torch.distributed.elastic.events.api import EventMetadataValue
from torch.distributed.elastic.metrics.api import prof
from torch.distributed.elastic.multiprocessing import PContext, start_processes
from torch.distributed.elastic.utils import macros
from torch.distributed.elastic.utils.logging import get_logger

log = get_logger()

__all__ = [
    "LocalElasticAgent",
    "TORCHELASTIC_ENABLE_FILE_TIMER",
    "TORCHELASTIC_TIMER_FILE",
]

TORCHELASTIC_ENABLE_FILE_TIMER = "TORCHELASTIC_ENABLE_FILE_TIMER"
TORCHELASTIC_TIMER_FILE = "TORCHELASTIC_TIMER_FILE"


class LocalElasticAgent(SimpleElasticAgent):
    """
    An implementation of :py:class:`torchelastic.agent.server.ElasticAgent`
    that handles host-local workers. This agent is deployed per host and is
    configured to spawn ``n`` workers. When using GPUs, ``n`` maps to the
    number of GPUs available on the host.

    The local agent does not communicate with local agents deployed on other
    hosts, even though its workers may communicate across hosts. A worker id
    is interpreted as a local process id. The agent starts and stops all
    worker processes as a single unit.

    The worker function and the arguments passed to it must be compatible
    with python multiprocessing. To pass multiprocessing data structures to
    the workers, create the data structure in the same multiprocessing
    context as the specified ``start_method`` and pass it as a function
    argument, as sketched below.
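
    A minimal sketch of this pattern (``trainer`` here is a hypothetical
    entrypoint; only the queue-creation pattern is the point)::

        ctx = multiprocessing.get_context("spawn")
        shared_queue = ctx.Queue()  # same context as start_method="spawn"
        spec = WorkerSpec(..., entrypoint=trainer, args=(shared_queue,))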

    The ``exit_barrier_timeout`` specifies the amount of time (in seconds) to
    wait for other agents to finish. This acts as a safety net for workers
    that finish at different times, preventing agents from treating an early
    finisher as a scale-down event. It is strongly advised that user code
    ensure workers terminate synchronously rather than relying on
    ``exit_barrier_timeout``.
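
    For illustration only (the value is an assumption, not a recommendation),
    a longer barrier can be requested at construction time::

        agent = LocalElasticAgent(spec, exit_barrier_timeout=600)  # 10 min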

    A named pipe based watchdog can be enabled in ``LocalElasticAgent`` if an
    environment variable ``TORCHELASTIC_ENABLE_FILE_TIMER`` with value 1 has
    been defined in the ``LocalElasticAgent`` process.
    Optionally, another environment variable ``TORCHELASTIC_TIMER_FILE``
    can be set to a unique file name for the named pipe. If
    ``TORCHELASTIC_TIMER_FILE`` is not set, ``LocalElasticAgent`` will
    internally create a unique file name and set it on that environment
    variable; the variable is propagated to the worker processes so that
    they can connect to the same named pipe that ``LocalElasticAgent`` uses.
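
    A minimal sketch of enabling the watchdog from the launching process (the
    pipe path below is an arbitrary example)::

        os.environ["TORCHELASTIC_ENABLE_FILE_TIMER"] = "1"
        # optional; a unique path under /tmp is generated when unset
        os.environ["TORCHELASTIC_TIMER_FILE"] = "/tmp/watchdog_timer_example"
        agent = LocalElasticAgent(spec)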

    Example launching function::

        def trainer(args) -> str:
            return "do train"

        def main():
            start_method = "spawn"
            shared_queue = multiprocessing.get_context(start_method).Queue()
            spec = WorkerSpec(
                role="trainer",
                local_world_size=nproc_per_node,
                entrypoint=trainer,
                args=("foobar",),
                ...<OTHER_PARAMS>...,
            )
            agent = LocalElasticAgent(spec, start_method)
            results = agent.run()

            if results.is_failed():
                print("trainer failed")
            else:
                print(f"rank 0 return value: {results.return_values[0]}")
                # prints -> rank 0 return value: do train

    Example launching binary::

        def main():
            spec = WorkerSpec(
                role="trainer",
                local_world_size=nproc_per_node,
                entrypoint="/usr/local/bin/trainer",
                args=("--trainer_args", "foobar"),
                ...<OTHER_PARAMS>...,
            )
            agent = LocalElasticAgent(spec)
            results = agent.run()

            if not results.is_failed():
                print("binary launches do not have return values")
    """

    def __init__(
        self,
        spec: WorkerSpec,
        start_method="spawn",
        exit_barrier_timeout: float = 300,
        log_dir: Optional[str] = None,
    ):
        super().__init__(spec, exit_barrier_timeout)
        self._start_method = start_method
        self._pcontext: Optional[PContext] = None
        rdzv_run_id = spec.rdzv_handler.get_run_id()
        self._log_dir = self._make_log_dir(log_dir, rdzv_run_id)
        self._worker_watchdog: Optional[timer.FileTimerServer] = None

    def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str):
        base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_")
        os.makedirs(base_log_dir, exist_ok=True)
        # each run gets its own unique subdirectory under the base log dir
        run_log_dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir)
        log.info(f"log directory set to: {run_log_dir}")
        return run_log_dir

    def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None:
        enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER
        watchdog_enabled = os.getenv(enable_watchdog_env_name)
        watchdog_file_env_name = TORCHELASTIC_TIMER_FILE
        watchdog_file_path = os.getenv(watchdog_file_env_name)
        if watchdog_enabled is not None and str(watchdog_enabled) == "1":
            if watchdog_file_path is None:
                watchdog_file_path = "/tmp/watchdog_timer_" + str(uuid.uuid4())
            log.info(f"Starting a FileTimerServer with {watchdog_file_path} ...")
            self._worker_watchdog = timer.FileTimerServer(
                file_path=watchdog_file_path,
                max_interval=0.1,
                daemon=True,
                log_event=self._log_watchdog_event,
            )
            self._worker_watchdog.start()
            log.info("FileTimerServer started")
        else:
            log.info(
                f"Environment variable '{enable_watchdog_env_name}' not found. "
                "Not starting FileTimerServer."
            )

        # propagate the pipe path to every worker's environment so workers
        # can connect to the same FileTimerServer
        if watchdog_file_path is not None:
            for worker_env in envs.values():
                worker_env[watchdog_file_env_name] = watchdog_file_path

    def _get_fq_hostname(self) -> str:
        return socket.getfqdn(socket.gethostname())

    def _log_watchdog_event(
        self,
        name: str,
        request: Optional[timer.FileTimerRequest],
    ) -> None:
        wg = self._worker_group
        spec = wg.spec
        md = {"watchdog_event": name}
        if request is not None:
            md["worker_pid"] = str(request.worker_pid)
            md["scope_id"] = request.scope_id
            md["expiration_time"] = str(request.expiration_time)
            md["signal"] = str(request.signal)
        md_str = json.dumps(md)
        state = "RUNNING"
        metadata: Dict[str, EventMetadataValue] = {
            "run_id": spec.rdzv_handler.get_run_id(),
            "global_rank": None,
            "group_rank": wg.group_rank,
            "worker_id": None,
            "role": spec.role,
            "hostname": self._get_fq_hostname(),
            "state": state,
            "total_run_time": self._total_execution_time,
            "rdzv_backend": spec.rdzv_handler.get_backend(),
            "raw_error": None,
            "metadata": md_str,
            "agent_restarts": spec.max_restarts - self._remaining_restarts,
        }
        # record a torchelastic event so watchdog activity is visible in the
        # event stream
        event = events.Event(
            name=name, source=events.EventSource.AGENT, metadata=metadata
        )
        events.record(event)

    @prof
    def _stop_workers(self, worker_group: WorkerGroup) -> None:
        self._shutdown()

    @prof
    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        spec = worker_group.spec
        store = worker_group.store
        assert store is not None
        master_addr, master_port = super()._get_master_addr_port(store)
        restart_count = spec.max_restarts - self._remaining_restarts

        use_agent_store = spec.rdzv_handler.get_backend() == "static"

        args: Dict[int, Tuple] = {}
        envs: Dict[int, Dict[str, str]] = {}
        for worker in worker_group.workers:
            local_rank = worker.local_rank
            worker_env = {
                "LOCAL_RANK": str(local_rank),
                "RANK": str(worker.global_rank),
                "GROUP_RANK": str(worker_group.group_rank),
                "ROLE_RANK": str(worker.role_rank),
                "ROLE_NAME": spec.role,
                "LOCAL_WORLD_SIZE": str(spec.local_world_size),
                "WORLD_SIZE": str(worker.world_size),
                "GROUP_WORLD_SIZE": str(worker_group.group_world_size),
                "ROLE_WORLD_SIZE": str(worker.role_world_size),
                "MASTER_ADDR": master_addr,
                "MASTER_PORT": str(master_port),
                "TORCHELASTIC_RESTART_COUNT": str(restart_count),
                "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
                "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
                "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
                "NCCL_ASYNC_ERROR_HANDLING": os.getenv(
                    "NCCL_ASYNC_ERROR_HANDLING", str(1)
                ),
            }
            if "OMP_NUM_THREADS" in os.environ:
                worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]

            envs[local_rank] = worker_env
            worker_args = list(spec.args)
            worker_args = macros.substitute(worker_args, str(local_rank))
            args[local_rank] = tuple(worker_args)

        # each restart attempt writes worker logs to its own subdirectory
        attempt_log_dir = os.path.join(self._log_dir, f"attempt_{restart_count}")
        shutil.rmtree(attempt_log_dir, ignore_errors=True)
        os.makedirs(attempt_log_dir)

        self._setup_local_watchdog(envs=envs)

        assert spec.entrypoint is not None
        self._pcontext = start_processes(
            name=spec.role,
            entrypoint=spec.entrypoint,
            args=args,
            envs=envs,
            log_dir=attempt_log_dir,
            start_method=self._start_method,
            redirects=spec.redirects,
            tee=spec.tee,
        )

        return self._pcontext.pids()

    def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None:
        if self._worker_watchdog is not None:
            self._worker_watchdog.stop()
            self._worker_watchdog = None
        if self._pcontext:
            self._pcontext.close(death_sig)

    @prof
    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
        role = worker_group.spec.role
        worker_pids = {w.id for w in worker_group.workers}
        assert self._pcontext is not None
        pc_pids = set(self._pcontext.pids().values())
        if worker_pids != pc_pids:
            log.error(
                f"[{role}] worker pids do not match process_context pids."
                f" Expected: {worker_pids}, actual: {pc_pids}"
            )
            return RunResult(state=WorkerState.UNKNOWN)

        result = self._pcontext.wait(0)
        if result:
            if result.is_failed():
                # map local rank failures to global ranks
                worker_failures = {}
                for local_rank, failure in result.failures.items():
                    worker = worker_group.workers[local_rank]
                    worker_failures[worker.global_rank] = failure
                return RunResult(
                    state=WorkerState.FAILED,
                    failures=worker_failures,
                )
            else:
                # map local rank return values to global ranks
                workers_ret_vals = {}
                for local_rank, ret_val in result.return_values.items():
                    worker = worker_group.workers[local_rank]
                    workers_ret_vals[worker.global_rank] = ret_val
                return RunResult(
                    state=WorkerState.SUCCEEDED,
                    return_values=workers_ret_vals,
                )
        else:
            return RunResult(state=WorkerState.HEALTHY)