"""Distributed and strategy routing helpers.""" from __future__ import annotations import os import torch def get_training_strategy(model_size_b: float) -> dict[str, object]: """Choose a training mode based on the visible hardware.""" n_gpus = torch.cuda.device_count() world_size = int(os.environ.get("WORLD_SIZE", "1")) n_nodes = max(1, world_size // max(n_gpus, 1)) if n_gpus else 1 has_cuda = torch.cuda.is_available() has_mps = torch.backends.mps.is_available() if not has_cuda and not has_mps: return {"mode": "cpu", "backend": None, "tp": 1, "pp": 1, "zero": 0} if has_mps: return {"mode": "mps-single", "backend": None, "tp": 1, "pp": 1, "zero": 0} vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 if n_nodes > 1: if model_size_b <= 1.0: return {"mode": "ddp", "backend": "nccl", "tp": 1, "pp": 1, "zero": 2} return {"mode": "fsdp", "backend": "nccl", "tp": 2, "pp": 1, "zero": 3} if n_gpus > 1: if model_size_b <= 1.0: return {"mode": "ddp", "backend": "nccl", "tp": 1, "pp": 1, "zero": 1} return {"mode": "fsdp", "backend": "nccl", "tp": 2, "pp": 1, "zero": 2} if vram_gb >= 40: return {"mode": "single", "backend": None, "tp": 1, "pp": 1, "zero": 0} if vram_gb >= 24: return {"mode": "single", "backend": None, "tp": 1, "pp": 1, "zero": 1} if vram_gb >= 16: return {"mode": "single", "backend": None, "tp": 1, "pp": 1, "zero": 2} return {"mode": "single", "backend": None, "tp": 1, "pp": 1, "zero": 3}