File size: 1,592 Bytes
"""Distributed and strategy routing helpers."""
from __future__ import annotations
import os
import torch
def get_training_strategy(model_size_b: float) -> dict[str, object]:
    """Choose a training mode based on the visible hardware.

    The decision order is: no accelerator -> CPU; MPS available -> single
    MPS device; more than one node -> DDP/FSDP; more than one local GPU ->
    DDP/FSDP; otherwise a single CUDA device whose ZeRO stage is picked
    from the VRAM of device 0.

    Args:
        model_size_b: Model size compared against a 1.0 threshold. Values
            at or below it select DDP; larger values select FSDP with
            ``tp=2``. (Presumably billions of parameters -- confirm with
            callers.)

    Returns:
        A dict with the keys ``"mode"``, ``"backend"``, ``"tp"``, ``"pp"``
        and ``"zero"``.

    Raises:
        ValueError: If the ``WORLD_SIZE`` environment variable is set to
            something ``int()`` cannot parse.
    """
    n_gpus = torch.cuda.device_count()
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    # Node count is estimated as total ranks divided by local GPUs, and is
    # never allowed to drop below one.
    n_nodes = max(1, world_size // n_gpus) if n_gpus else 1

    has_cuda = torch.cuda.is_available()
    # Look the MPS backend up defensively so a torch build without
    # ``torch.backends.mps`` is treated as "no MPS" instead of raising.
    mps_backend = getattr(torch.backends, "mps", None)
    has_mps = mps_backend is not None and bool(mps_backend.is_available())

    if not has_cuda and not has_mps:
        return {"mode": "cpu", "backend": None, "tp": 1, "pp": 1, "zero": 0}
    if has_mps:
        return {"mode": "mps-single", "backend": None, "tp": 1, "pp": 1, "zero": 0}

    small_model = model_size_b <= 1.0

    if n_nodes > 1:
        if small_model:
            return {"mode": "ddp", "backend": "nccl", "tp": 1, "pp": 1, "zero": 2}
        return {"mode": "fsdp", "backend": "nccl", "tp": 2, "pp": 1, "zero": 3}

    if n_gpus > 1:
        if small_model:
            return {"mode": "ddp", "backend": "nccl", "tp": 1, "pp": 1, "zero": 1}
        return {"mode": "fsdp", "backend": "nccl", "tp": 2, "pp": 1, "zero": 2}

    # Only the single-GPU tiers depend on VRAM, so the device query is
    # deferred until here rather than run on the distributed paths too.
    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9

    # (minimum VRAM in GB, ZeRO stage), checked from largest to smallest;
    # anything below the last tier falls through to stage 3.
    for min_vram_gb, zero_stage in ((40, 0), (24, 1), (16, 2)):
        if vram_gb >= min_vram_gb:
            return {"mode": "single", "backend": None, "tp": 1, "pp": 1, "zero": zero_stage}
    return {"mode": "single", "backend": None, "tp": 1, "pp": 1, "zero": 3}
|