|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import datetime |
|
|
import socket |
|
|
from contextlib import closing |
|
|
|
|
|
import torch.distributed as dist |
|
|
from torch.distributed.elastic.utils.logging import get_logger |
|
|
|
|
|
|
|
|
# Module-level logger shared by the helpers in this file.
log = get_logger()

# Error-message texts compared against ``str(RuntimeError)`` further down:
# the failures are recognised by their message rather than by exception type.
_ADDRESS_IN_USE = "Address already in use"
_SOCKET_TIMEOUT = "Socket Timeout"

# Store keys used by ``_check_full_rank``: a counter bumped by every member
# that checks in, and a key set by the member whose check-in brings the
# counter up to ``world_size``.
_MEMBER_CHECKIN = "_tcp_store/num_members"
_LAST_MEMBER_CHECKIN = "_tcp_store/last_member"
|
|
|
|
|
|
|
|
def create_c10d_store(
    is_server: bool,
    server_addr: str,
    server_port: int = -1,
    world_size: int = 1,
    timeout: float = (60 * 10),
    wait_for_workers: bool = True,
    retries: int = 3,
):
    """
    Create a ``dist.TCPStore`` reachable at ``server_addr``.

    Args:
        is_server: passed to ``TCPStore`` as ``is_master``.
        server_addr: host name passed to ``TCPStore`` as ``host_name``.
        server_port: port to use. ``-1`` means "pick a free port via
            ``get_free_port()``", which is only allowed when ``world_size``
            is 1.
        world_size: number of members expected to use the store.
        timeout: store timeout in seconds.
        wait_for_workers: forwarded to ``TCPStore``; when true, the call also
            runs ``_check_full_rank`` before returning.
        retries: maximum number of creation attempts when the port is picked
            dynamically. Ignored when ``server_port`` is given explicitly.

    Returns:
        The created store.

    Raises:
        ValueError: if ``server_port`` is ``-1`` while ``world_size > 1``.
        RuntimeError: if the port is already in use and no attempts are left;
            any other ``RuntimeError`` from store creation is re-raised as-is.
        TimeoutError: from ``_check_full_rank`` if not all members join in time.
    """
    if server_port == -1 and world_size > 1:
        raise ValueError(
            f"server_port must be specified when world_size > 1, "
            f"got server_port={server_port}, world_size={world_size}"
        )

    if server_port != -1:
        log.info(f"server_port: {server_port}, specified, ignoring retries")

    # Only retry when server_port is NOT static: retrying a fixed port would
    # hit the same conflict again, so in that case the attempt budget starts
    # out exhausted and the first port conflict is reported immediately.
    attempt = retries if server_port != -1 else 1
    while True:
        if server_port != -1:
            port = server_port
        else:
            port = get_free_port()

        log.info(
            f"Creating c10d store on {server_addr}:{port}\n"
            f"  world_size  : {world_size}\n"
            f"  is_server   : {is_server}\n"
            f"  timeout(sec): {timeout}\n"
        )

        try:
            store = dist.TCPStore(
                host_name=server_addr,
                port=port,
                world_size=world_size,
                is_master=is_server,
                timeout=datetime.timedelta(seconds=timeout),
                wait_for_workers=wait_for_workers,
            )
            # The full-rank check is skipped when we do not have to wait for
            # all workers.
            if wait_for_workers:
                _check_full_rank(store, world_size)
            log.info("Successfully created c10d store")
            return store
        except RuntimeError as e:
            # The failure is identified by its message text; anything that is
            # not a port conflict is propagated untouched.
            if str(e) != _ADDRESS_IN_USE:
                raise
            if attempt >= retries:
                raise RuntimeError(
                    f"on {server_addr}, port: {port} already in use"
                ) from e
            log.warning(
                f"port: {port} already in use, attempt: [{attempt}/{retries}]"
            )
            attempt += 1
|
|
|
|
|
|
|
|
def _check_full_rank(store, world_size):
    """
    Check in with ``store`` and wait for the last member's marker key.

    Adds 1 to the ``_MEMBER_CHECKIN`` entry; the caller whose add returns
    ``world_size`` sets ``_LAST_MEMBER_CHECKIN``. Every caller then reads
    ``_LAST_MEMBER_CHECKIN`` (presumably blocking until the key exists or the
    store's timeout elapses -- this relies on the store's ``get`` semantics).

    Raises:
        TimeoutError: if the read fails with the "Socket Timeout" message.
        RuntimeError: any other ``RuntimeError`` from the read, unchanged.
    """
    checked_in = store.add(_MEMBER_CHECKIN, 1)
    if checked_in == world_size:
        # The stored value is irrelevant; only the key's presence matters.
        store.set(_LAST_MEMBER_CHECKIN, "<val_ignored>")

    try:
        store.get(_LAST_MEMBER_CHECKIN)
    except RuntimeError as err:
        if str(err) != _SOCKET_TIMEOUT:
            raise
        raise TimeoutError(
            f"timed out waiting for all {world_size} members to join"
        ) from err
|
|
|
|
|
|
|
|
def get_free_port():
    """
    Return the port number of a socket obtained from ``get_socket_with_port()``.

    The socket is closed before returning, so another process may take the
    port before the caller gets to use it.
    """
    with closing(get_socket_with_port()) as reserved:
        return reserved.getsockname()[1]
|
|
|
|
|
|
|
|
def get_socket_with_port() -> socket.socket:
    """
    Return a listening socket bound to a free port on localhost.

    The port stays "reserved" for as long as the returned socket is open.
    Close the socket before passing the port to the entity that requires it.
    Usage example

    ::

     sock = get_socket_with_port()
     with closing(sock):
         port = sock.getsockname()[1]
     # the socket is closed here, so there is still a race-condition that
     # some other process may grab this port before func() runs
     func(port)

    Every address that ``localhost`` resolves to is tried in order. An address
    for which the socket cannot be created, bound or put into listening mode
    is logged and skipped.

    Raises:
        RuntimeError: if none of the candidate addresses yields a socket.
    """
    addrs = socket.getaddrinfo(
        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
    )
    for family, sock_type, proto, _, _ in addrs:
        # Creating the socket can itself fail (e.g. an address family the
        # host does not support); treat that like a failed bind and move on
        # to the next candidate instead of aborting the search.
        try:
            s = socket.socket(family, sock_type, proto)
        except OSError as e:
            log.info("Socket creation attempt failed.", exc_info=e)
            continue
        try:
            s.bind(("localhost", 0))
            s.listen(0)
            return s
        except OSError as e:
            s.close()
            log.info("Socket creation attempt failed.", exc_info=e)
    raise RuntimeError("Failed to create a socket")
|
|
|