# File size: 6,412 Bytes
# 3d79eb3
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#
import logging
import os
import random
import subprocess
import warnings
from datetime import timedelta
from functools import partial
from typing import Any, List, Literal, Optional, Set, Tuple, Type
import submitit
import torch
import torch.distributed as dist
from fairseq2.gang import Gang, ProcessGroupGang
from fairseq2.logging import get_log_writer
from fairseq2.nn.fsdp import (
FSDP_LOW_MEMORY_POLICY,
FSDP_STANDARD_MEMORY_POLICY,
FSDP_VERY_LOW_MEMORY_POLICY,
FSDPMemoryPolicy,
FSDPWrapPolicy,
)
from fairseq2.nn.transformer import (
TransformerDecoder,
TransformerDecoderLayer,
TransformerEncoder,
TransformerEncoderLayer,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.nn import Module
# Module-level log writer (fairseq2's wrapper around the stdlib logger).
logger = get_log_writer(__name__)
# Accepted preset names for the helpers below; used as Literal annotations so
# type checkers flag unsupported policy strings at call sites.
SUPPORTED_FSDP_MEMORY_POLICIES = Literal["standard", "low", "very_low"]
SUPPORTED_FSDP_WRAP_POLICIES = Literal["layer", "stack", "model"]
def get_fsdp_memory_policy(
    policy: SUPPORTED_FSDP_MEMORY_POLICIES = "standard",
) -> FSDPMemoryPolicy:
    """Return the fairseq2 FSDP memory policy preset named by ``policy``.

    :param policy:
        One of ``"standard"``, ``"low"``, or ``"very_low"``.
    :returns:
        The matching fairseq2 ``FSDPMemoryPolicy`` constant.
    :raises ValueError:
        If ``policy`` is not a supported preset name.
    """
    presets = {
        "standard": FSDP_STANDARD_MEMORY_POLICY,
        "low": FSDP_LOW_MEMORY_POLICY,
        "very_low": FSDP_VERY_LOW_MEMORY_POLICY,
    }
    try:
        return presets[policy]
    except KeyError:
        # Bug fix: the original message was not an f-string (it printed the
        # literal text "{policy}") and ended with an empty "{}" placeholder.
        raise ValueError(
            f"Unsupported policy {policy!r}. Choose from 'standard', 'low', 'very_low'."
        ) from None
def get_fsdp_wrap_policy(
    model: Module, wrap_granularity: SUPPORTED_FSDP_WRAP_POLICIES = "layer"
) -> Tuple[Optional[FSDPWrapPolicy], Optional[List[Module]]]:
    """Return the FSDP wrap policy for ``model`` along with ignored modules.

    Copied over from fs2 to experiment easily with fsdp wrap policies.

    :param model:
        The model to be wrapped.
    :param wrap_granularity:
        The granularity at which to wrap modules of ``model``:
        'layer' wraps individual layers (e.g. :class:`TransformerDecoderLayer`),
        'stack' wraps layer stacks (e.g. :class:`TransformerDecoder`), and
        'model' wraps ``model`` only.
    :raises ValueError:
        If ``wrap_granularity`` is not one of the supported values.
    """
    # 'model' means no auto-wrap policy: FSDP wraps only the root module.
    if wrap_granularity == "model":
        return None, None
    layer_types: Set[Type[Module]]
    if wrap_granularity == "stack":
        layer_types = {TransformerEncoder, TransformerDecoder}
    elif wrap_granularity == "layer":
        layer_types = {TransformerEncoderLayer, TransformerDecoderLayer}
    else:
        raise ValueError(
            f"`wrap_granularity` must be 'layer', 'stack', or 'model', but is '{wrap_granularity}' instead."
        )
    # Bind the chosen layer classes into torch's transformer auto-wrap policy.
    return partial(transformer_auto_wrap_policy, transformer_layer_cls=layer_types), None
def init_process_group(config: Any, logger: logging.Logger) -> Gang:
    """Initialize and return the default fairseq2 gang (process group).

    :param config:
        Run configuration; only the optional ``use_submitit`` attribute is
        read (defaults to ``True`` when absent).
    :param logger:
        Logger used to report successful initialization.
    :returns:
        A :class:`ProcessGroupGang` wrapping the freshly created default
        process group.
    """
    if getattr(config, "use_submitit", True):
        try:
            # Export RANK / WORLD_SIZE / MASTER_ADDR etc. from the submitit
            # job environment so torch.distributed can rendezvous.
            submitit.helpers.TorchDistributedEnvironment().export(overwrite=True)
            # Surface NCCL async errors instead of hanging.
            os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
        except RuntimeError:
            # Not running under submitit/stopes; warn rather than fail so
            # local (non-cluster) runs still work.
            warnings.warn(
                "looks like you are not in a submitit/stopes job. \
You probably want to override use_submitit=false",
                stacklevel=2,
            )
    timeout = timedelta(minutes=15)
    # ok_initialized=False: fail loudly if a process group already exists.
    gang = ProcessGroupGang.init_default_process_group(
        ok_initialized=False,
        timeout=timeout,
    )
    logger.info(f"Initialized gang with default process group (timeout={timeout})")
    return gang
def is_torch_run() -> bool:
    """Return ``True`` when launched via torchrun/TorchElastic."""
    # TorchElastic always sets this variable for its workers.
    return "TORCHELASTIC_RUN_ID" in os.environ
def is_slurm_job() -> bool:
    """Return ``True`` when running inside a SLURM allocation."""
    # SLURM sets SLURM_JOB_ID for every task it launches.
    return os.environ.get("SLURM_JOB_ID") is not None
def get_global_rank() -> int:
    """Best-effort global rank of this process.

    Resolution order: an initialized ``torch.distributed`` process group,
    then the torchrun environment, then SLURM; defaults to 0 (single
    process).
    """
    if dist.is_initialized():
        return dist.get_rank()
    # Fall back to launcher-provided environment variables.
    for detector, env_var in ((is_torch_run, "RANK"), (is_slurm_job, "SLURM_PROCID")):
        if detector():
            return int(os.environ[env_var])
    return 0
def get_local_rank() -> int:
    """Rank of this process within its node (0 when not distributed).

    Checks the torchrun environment first, then SLURM.
    """
    for detector, env_var in ((is_torch_run, "LOCAL_RANK"), (is_slurm_job, "SLURM_LOCALID")):
        if detector():
            return int(os.environ[env_var])
    return 0
def get_world_size() -> int:
    """Best-effort total number of processes in the job.

    Resolution order: an initialized ``torch.distributed`` process group,
    then the torchrun environment, then SLURM; defaults to 1 (single
    process).
    """
    if dist.is_initialized():
        return dist.get_world_size()
    for detector, env_var in ((is_torch_run, "WORLD_SIZE"), (is_slurm_job, "SLURM_NTASKS")):
        if detector():
            return int(os.environ[env_var])
    return 1
def get_master_addr() -> str:
    """Resolve the rendezvous host for torch.distributed.

    Under torchrun the launcher dictates it; under SLURM the first node of
    the job's node list is used; otherwise localhost.
    """
    if is_torch_run():
        return os.environ["MASTER_ADDR"]
    if not is_slurm_job():
        return "127.0.0.1"
    # Expand the (possibly compressed) SLURM node list and pick the first host.
    nodelist = os.environ["SLURM_JOB_NODELIST"]
    raw = subprocess.check_output(["scontrol", "show", "hostnames", nodelist])
    return raw.split()[0].decode("utf-8")
def get_master_port(job_id: int) -> int:
    """Pick the rendezvous TCP port for torch.distributed.

    Under torchrun the port is dictated by the launcher environment;
    otherwise a port is derived deterministically from ``job_id`` so that
    every rank of the same job agrees on it without communicating.

    :param job_id:
        The job identifier used to seed the port choice (e.g. SLURM job id).
    :returns:
        A port number; the original ``Optional[int]`` annotation was
        misleading — both branches always return an ``int``.
    """
    if is_torch_run():
        return int(os.environ["MASTER_PORT"])
    # Deterministic per-job choice in [20000, 60000]: seeding with the job id
    # makes all ranks draw the same port.
    MIN_MASTER_PORT, MAX_MASTER_PORT = (20000, 60000)
    rng = random.Random(job_id)
    return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)
def init_torch_distributed(
    backend: str = "cpu:gloo,cuda:nccl",
    port: Optional[str] = None,
    max_attempt: int = 5,
) -> None:
    """Initialize ``torch.distributed`` from torchrun/SLURM environment info.

    No-op if a process group is already initialized. Rank, world size, and
    master address/port are resolved via the helper functions in this module
    and exported to the environment before ``dist.init_process_group``.

    :param backend:
        torch.distributed backend spec; when it includes ``nccl`` the
        current CUDA device is bound to the local rank first.
    :param port:
        Master port to use; when ``None``, one is derived from the SLURM
        job id (or ``-1`` outside SLURM).
    :param max_attempt:
        How many initialization attempts before giving up; can be
        overridden via the ``TORCH_DISTRIBUTED_PORT_ATTEMPTS`` environment
        variable.
    :raises RuntimeError:
        If initialization still fails after ``max_attempt`` attempts.
    """
    if dist.is_initialized():
        return
    os.environ["RANK"] = str(get_global_rank())
    os.environ["WORLD_SIZE"] = str(get_world_size())
    master_addr = get_master_addr()
    # Allow max_attempt to be overridden via the environment.
    env_attempts = os.environ.get("TORCH_DISTRIBUTED_PORT_ATTEMPTS")
    if env_attempts:
        max_attempt = int(env_attempts)
    attempt = 0
    while True:
        try:
            os.environ["MASTER_ADDR"] = master_addr
            if port is None:
                port = str(
                    get_master_port(job_id=int(os.environ.get("SLURM_JOB_ID", -1)))
                )
            os.environ["MASTER_PORT"] = port
            local_rank = get_local_rank()
            if "nccl" in backend:
                # NCCL requires each process to be bound to its own GPU
                # before the process group is created.
                torch.cuda.set_device(local_rank)
            timeout = timedelta(hours=10)
            dist.init_process_group(backend=backend, timeout=timeout)
            break
        except (dist.DistNetworkError, RuntimeError) as e:
            attempt += 1
            if attempt == max_attempt:
                # Bug fix: report the actual attempt budget instead of the
                # hardcoded "5" (max_attempt may differ via the env override).
                raise RuntimeError(
                    f"Failed to initialize torch.distributed after {max_attempt} attempts"
                ) from e
|