| |
| |
| |
| |
| |
|
|
| import binascii |
| import logging |
| import os |
| import tempfile |
| from base64 import b64decode, b64encode |
| from datetime import timedelta |
| from typing import Any, Optional, Tuple, cast |
|
|
| from torch.distributed import FileStore, Store, TCPStore |
| from torch.distributed.elastic.events import ( |
| NodeState, |
| construct_and_record_rdzv_event, |
| ) |
|
|
| from .api import ( |
| RendezvousConnectionError, |
| RendezvousError, |
| RendezvousParameters, |
| RendezvousStateError, |
| ) |
| from .dynamic_rendezvous import RendezvousBackend, Token |
| from .utils import _matches_machine_hostname, parse_rendezvous_endpoint |
|
|
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)
|
|
|
|
class C10dRendezvousBackend(RendezvousBackend):
    """Represents a C10d-backed rendezvous backend.

    The rendezvous state is kept base64-encoded under a single store key
    derived from the run id. The base64 payload returned by the store doubles
    as the fencing token handed back to callers.

    Args:
        store:
            The :py:class:`torch.distributed.Store` instance to use to
            communicate with the C10d store.
        run_id:
            The run id of the rendezvous.

    Raises:
        ValueError:
            If ``run_id`` is empty.
        RendezvousConnectionError:
            If a store operation fails.
    """

    # Placeholder written under the rendezvous key while no real state exists;
    # ``_decode_state`` translates it back to ``None``.
    _NULL_SENTINEL = "Y2FuaW1hZGFt"

    _store: Store
    _key: str

    def __init__(self, store: Store, run_id: str) -> None:
        if not run_id:
            raise ValueError("The run id must be a non-empty string.")

        self._store = store
        self._key = "torch.rendezvous." + run_id

        # Seed the key with the sentinel using an empty expected value.
        # NOTE(review): presumably this guarantees that a later ``get`` always
        # finds the key instead of waiting for it to appear -- confirm against
        # the Store documentation.
        self._call_store("compare_set", self._key, "", self._NULL_SENTINEL)

    @property
    def name(self) -> str:
        """See base class."""
        return "c10d"

    def get_state(self) -> Optional[Tuple[bytes, Token]]:
        """See base class."""
        raw_value: bytes = self._call_store("get", self._key)
        return self._decode_state(raw_value)

    def set_state(
        self, state: bytes, token: Optional[Token] = None
    ) -> Optional[Tuple[bytes, Token, bool]]:
        """See base class."""
        encoded_state: str = b64encode(state).decode()

        if not token:
            # No token supplied: only overwrite the initial sentinel.
            expected = self._NULL_SENTINEL
        elif isinstance(token, bytes):
            expected = token.decode()
        else:
            # Tokens produced by this backend are always bytes, so anything
            # else cannot match; report the current state as a failed write
            # without attempting the compare-and-set.
            current = self.get_state()
            if current is None:
                return None
            return (*current, False)

        raw_value: bytes = self._call_store(
            "compare_set", self._key, expected, encoded_state
        )

        decoded = self._decode_state(raw_value)
        if decoded is None:
            return None

        remote_state, remote_token = decoded

        # The write is considered successful when the state reported back by
        # the store is byte-for-byte the state we tried to write.
        return remote_state, remote_token, remote_state == state

    def _call_store(self, store_op: str, *args, **kwargs) -> Any:
        """Invoke ``store_op`` on the store, wrapping store failures."""
        try:
            operation = getattr(self._store, store_op)
            return operation(*args, **kwargs)
        except (ValueError, RuntimeError, TimeoutError) as exc:
            raise RendezvousConnectionError(
                "The connection to the C10d store has failed. See inner exception for details."
            ) from exc

    def _decode_state(self, base64_state: bytes) -> Optional[Tuple[bytes, Token]]:
        """Turn a raw store value into a ``(state, token)`` pair.

        Returns ``None`` for the sentinel value; raises
        ``RendezvousStateError`` when the value is not valid base64.
        """
        sentinel = self._NULL_SENTINEL.encode()
        if base64_state == sentinel:
            return None

        try:
            decoded_state = b64decode(base64_state)
        except binascii.Error as exc:
            raise RendezvousStateError(
                "The state object is corrupt. See inner exception for details."
            ) from exc
        else:
            # The encoded payload itself serves as the token.
            return decoded_state, base64_state
|
|
|
|
def _create_tcp_store(params: RendezvousParameters) -> TCPStore:
    """Create a :py:class:`torch.distributed.TCPStore` from ``params``.

    The endpoint is parsed with a default port of 29400. Whether this process
    acts as the store server comes from the ``is_host`` setting, or, when that
    is not configured, from matching the endpoint host against this machine.

    Raises:
        ValueError:
            If ``read_timeout`` is not a positive integer.
        RendezvousConnectionError:
            If the store could not be instantiated.
    """
    host, port = parse_rendezvous_endpoint(params.endpoint, default_port=29400)

    configured_is_host = params.get_as_bool("is_host")
    if configured_is_host is None:
        # Nothing configured explicitly; fall back to the hostname heuristic.
        is_host = _matches_machine_hostname(host)
    else:
        is_host = configured_is_host

    # The read timeout is expressed in seconds.
    read_timeout = cast(int, params.get_as_int("read_timeout", 60))
    if read_timeout <= 0:
        raise ValueError("The read timeout must be a positive integer.")

    # Up to two attempts: first with the determined role, then as a client.
    for as_server in (is_host, False):
        try:
            store = TCPStore(
                host,
                port,
                is_master=as_server,
                timeout=timedelta(seconds=read_timeout),
            )

            if as_server:
                msg = f"Process {os.getpid()} hosts the TCP store for the C10d rendezvous backend."
                construct_and_record_rdzv_event(
                    run_id=params.run_id, message=msg, node_state=NodeState.INIT
                )
                log.info(msg)

            break
        except (ValueError, RuntimeError, TimeoutError) as exc:
            # Retry as a client only if the failed attempt was a server
            # attempt whose role was inferred rather than configured.
            # NOTE(review): presumably another process on this machine may
            # already be hosting the store -- confirm.
            retry_as_client = as_server and configured_is_host is None
            if not retry_as_client:
                raise RendezvousConnectionError(
                    "The connection to the C10d store has failed. See inner exception for details."
                ) from exc

    return store
|
|
|
|
def _create_file_store(params: RendezvousParameters) -> FileStore:
    """Create a :py:class:`torch.distributed.FileStore` from ``params``.

    When ``params.endpoint`` is set it is used as the path of the backing
    file; otherwise a fresh temporary file is created for the store.

    Raises:
        RendezvousError:
            If the temporary file could not be created.
        RendezvousConnectionError:
            If the store could not be instantiated.
    """
    # If a user specifies an endpoint, we treat it as a path to a file.
    if params.endpoint:
        path = params.endpoint
    else:
        try:
            fd, path = tempfile.mkstemp()
            # mkstemp() returns an *open* OS-level descriptor. Only the path
            # is needed here, so close the descriptor right away instead of
            # leaking one per call.
            os.close(fd)
        except OSError as exc:
            raise RendezvousError(
                "The file creation for C10d store has failed. See inner exception for details."
            ) from exc

    try:
        store = FileStore(path)
    except (ValueError, RuntimeError) as exc:
        raise RendezvousConnectionError(
            "The connection to the C10d store has failed. See inner exception for details."
        ) from exc

    return store
|
|
|
|
def create_backend(params: RendezvousParameters) -> Tuple[C10dRendezvousBackend, Store]:
    """Creates a new :py:class:`C10dRendezvousBackend` from the specified
    parameters.

    +--------------+-----------------------------------------------------------+
    | Parameter    | Description                                               |
    +==============+===========================================================+
    | store_type   | The type of the C10d store. The currently supported types |
    |              | are "tcp" and "file" which correspond to                  |
    |              | :py:class:`torch.distributed.TCPStore` and                |
    |              | :py:class:`torch.distributed.FileStore`, respectively.    |
    |              | Defaults to "tcp".                                        |
    +--------------+-----------------------------------------------------------+
    | read_timeout | The read timeout, in seconds, for store operations.       |
    |              | Defaults to 60 seconds.                                   |
    |              |                                                           |
    |              | Note this only applies to                                 |
    |              | :py:class:`torch.distributed.TCPStore`. It is not         |
    |              | relevant to :py:class:`torch.distributed.FileStore` which |
    |              | does not take in timeout as a parameter.                  |
    +--------------+-----------------------------------------------------------+
    | is_host      | A boolean value indicating whether this backend instance  |
    |              | will host the C10d store. If not specified it will be     |
    |              | inferred heuristically by matching the hostname or the IP |
    |              | address of this machine against the specified rendezvous  |
    |              | endpoint. Defaults to ``None``.                           |
    |              |                                                           |
    |              | Note that this configuration option only applies to       |
    |              | :py:class:`torch.distributed.TCPStore`. In normal         |
    |              | circumstances you can safely skip it; the only time when  |
    |              | it is needed is if its value cannot be correctly          |
    |              | determined (e.g. the rendezvous endpoint has a CNAME as   |
    |              | the hostname or does not match the FQDN of the machine).  |
    +--------------+-----------------------------------------------------------+

    Returns:
        A tuple of the new backend and the store it uses.

    Raises:
        ValueError:
            If ``store_type`` is neither "file" nor "tcp".
    """
    # The store type is matched case-insensitively, ignoring surrounding
    # whitespace.
    store_type = params.get("store_type", "tcp").strip().lower()
    store: Store

    try:
        if store_type not in ("file", "tcp"):
            raise ValueError("Invalid store type given. Currently only supports file and tcp.")

        if store_type == "file":
            store = _create_file_store(params)
        else:
            store = _create_tcp_store(params)

        backend = C10dRendezvousBackend(store, params.run_id)
    except Exception as e:
        # Record the failure as a rendezvous event, then let the original
        # exception propagate unchanged.
        construct_and_record_rdzv_event(
            message=f"{type(e).__name__}: {str(e)}",
            run_id=params.run_id,
            node_state=NodeState.FAILED,
        )
        raise

    return backend, store
|
|