|
|
""" |
|
|
Distributed training utilities for multi-GPU training. |
|
|
|
|
|
Supports both DDP (Distributed Data Parallel) and FSDP (Fully Sharded Data Parallel). |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import os |
|
|
from typing import Optional |
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
from torch.utils.data.distributed import DistributedSampler |
|
|
|
|
|
# Module-level logger named after this module, so that handlers and levels
# can be configured per module by the application.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def setup_ddp(rank: int, world_size: int, backend: str = "nccl"):
    """
    Initialize the distributed training environment for this process.

    MASTER_ADDR and MASTER_PORT are kept when already present in the
    environment and otherwise default to "localhost" and "12355".

    After the process group is created, the process is bound to a CUDA
    device when CUDA is available. The device index is taken from the
    LOCAL_RANK environment variable (exported by launchers such as
    torchrun) when it is set, and falls back to
    ``rank % torch.cuda.device_count()`` otherwise. On a single node this
    is identical to using ``rank`` directly, while on multiple nodes it
    avoids selecting a device index that does not exist on the local
    machine. When CUDA is not available (e.g. the 'gloo' backend on a
    CPU-only host) no device is selected.

    Args:
        rank: Process rank (0 to world_size-1)
        world_size: Total number of processes
        backend: Communication backend ('nccl' for GPU, 'gloo' for CPU)
    """
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")

    dist.init_process_group(
        backend=backend,
        rank=rank,
        world_size=world_size,
    )

    # Only touch CUDA when it exists: CPU-only runs must not fail here.
    if torch.cuda.is_available():
        local_rank = os.environ.get("LOCAL_RANK")
        if local_rank is not None:
            device_index = int(local_rank)
        else:
            # The global rank can exceed the number of local GPUs on
            # multi-node jobs, so wrap it into the local device range.
            device_index = rank % torch.cuda.device_count()
        torch.cuda.set_device(device_index)

    logger.info(f"DDP initialized: rank={rank}, world_size={world_size}, backend={backend}")
|
|
|
|
|
|
|
|
def cleanup_ddp():
    """Destroy the default process group if one is currently initialized."""
    if not dist.is_initialized():
        # Nothing was set up (or it was already torn down): nothing to do.
        return

    dist.destroy_process_group()
    logger.info("DDP cleaned up")
|
|
|
|
|
|
|
|
def get_ddp_info() -> dict:
    """
    Describe the current distributed configuration.

    Returns:
        Dict with the keys "is_initialized", "rank", "world_size" and
        "backend". When no process group is initialized the values are
        False, 0, 1 and None respectively.
    """
    active = dist.is_initialized()
    info = {"is_initialized": active}

    if active:
        info["rank"] = dist.get_rank()
        info["world_size"] = dist.get_world_size()
        info["backend"] = dist.get_backend()
    else:
        # Single-process defaults used when no process group exists.
        info["rank"] = 0
        info["world_size"] = 1
        info["backend"] = None

    return info
|
|
|
|
|
|
|
|
def wrap_model_ddp(
    model: torch.nn.Module,
    device: str = "cuda",
    find_unused_parameters: bool = False,
    gradient_as_bucket_view: bool = True,
) -> torch.nn.Module:
    """
    Wrap model with DDP for distributed training.

    When no process group is initialized, a warning is logged and the model
    is returned unchanged.

    For ``device == "cuda"`` the local CUDA device index is taken from the
    LOCAL_RANK environment variable when set, otherwise from
    ``rank % torch.cuda.device_count()``; using the global rank directly
    would select a non-existent device on multi-node jobs. The model is
    moved to that device before wrapping, because DDP rejects ``device_ids``
    for a module whose parameters live on another device. For any other
    ``device`` value the model is wrapped as-is without ``device_ids``.

    Args:
        model: Model to wrap
        device: Device to use
        find_unused_parameters: Whether to find unused parameters (slower but more flexible)
        gradient_as_bucket_view: Use gradient as bucket view for memory efficiency

    Returns:
        DDP-wrapped model, or the original model if DDP is not initialized
    """
    if not dist.is_initialized():
        logger.warning("DDP not initialized, returning unwrapped model")
        return model

    rank = dist.get_rank()
    if device == "cuda":
        local_rank = os.environ.get("LOCAL_RANK")
        if local_rank is not None:
            device_id = int(local_rank)
        else:
            gpu_count = torch.cuda.device_count()
            # With no visible GPU, keep the raw rank so set_device reports
            # the problem instead of failing on a modulo by zero.
            device_id = rank % gpu_count if gpu_count > 0 else rank
        torch.cuda.set_device(device_id)
        # Module.to() moves parameters in place; a no-op when already there.
        model = model.to(torch.device("cuda", device_id))
    else:
        device_id = None

    ddp_model = DDP(
        model,
        device_ids=[device_id] if device_id is not None else None,
        output_device=device_id,
        find_unused_parameters=find_unused_parameters,
        gradient_as_bucket_view=gradient_as_bucket_view,
    )

    logger.info(f"Model wrapped with DDP (rank={rank})")
    return ddp_model
|
|
|
|
|
|
|
|
def create_distributed_sampler(
    dataset,
    shuffle: bool = True,
    seed: int = 0,
) -> Optional[DistributedSampler]:
    """
    Build a DistributedSampler for the current process.

    The number of replicas and the rank are read from the initialized
    process group.

    Args:
        dataset: Dataset to sample from
        shuffle: Whether to shuffle
        seed: Random seed

    Returns:
        DistributedSampler if DDP is initialized, None otherwise
    """
    if dist.is_initialized():
        replicas = dist.get_world_size()
        current_rank = dist.get_rank()

        result = DistributedSampler(
            dataset,
            num_replicas=replicas,
            rank=current_rank,
            shuffle=shuffle,
            seed=seed,
        )

        logger.info(f"Created DistributedSampler (rank={current_rank}/{replicas})")
        return result

    # No process group: the caller should fall back to regular sampling.
    return None
|
|
|
|
|
|
|
|
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """
    All-reduce tensor and compute mean across all processes.

    When no process group is initialized the input is returned untouched.
    Otherwise the input is summed across processes in place. Floating-point
    and complex tensors are then divided in place and the same tensor object
    is returned. Integer and boolean tensors cannot hold a true-division
    result in place (PyTorch raises a RuntimeError), so for those a new,
    type-promoted tensor holding the mean is returned while the input keeps
    the summed values.

    Args:
        tensor: Tensor to reduce (modified in place by the reduction)

    Returns:
        Mean value across all processes
    """
    if not dist.is_initialized():
        return tensor

    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    world_size = dist.get_world_size()

    if tensor.is_floating_point() or tensor.is_complex():
        tensor /= world_size
        return tensor

    # Out-of-place true division promotes integer inputs to a float dtype.
    return tensor / world_size
|
|
|
|
|
|
|
|
def save_checkpoint_ddp(
    model: torch.nn.Module,
    optimizer,
    scheduler,
    epoch: int,
    loss: float,
    checkpoint_path: str,
    is_main_process: bool = True,
):
    """
    Save checkpoint (only on main process to avoid conflicts).

    The saved dict has the keys "epoch", "model_state_dict",
    "optimizer_state_dict", "scheduler_state_dict" and "loss". A
    DDP-wrapped model is unwrapped first so the stored parameter names
    carry no "module." prefix. ``optimizer`` and ``scheduler`` may be None,
    in which case None is stored for the corresponding entry. The parent
    directory of ``checkpoint_path`` is created when it does not exist.

    When a process group is initialized, every process waits on a barrier
    afterwards. The barrier is reached even if saving raises, so a failure
    on the main process does not leave the other processes blocked; the
    exception is re-raised after the barrier.

    Args:
        model: Model to save
        optimizer: Optimizer state (or None)
        scheduler: Scheduler state (or None)
        epoch: Current epoch
        loss: Current loss
        checkpoint_path: Path to save checkpoint
        is_main_process: Whether this is the main process (rank 0)
    """
    try:
        if is_main_process:
            # Get model state dict (unwrap DDP if needed)
            bare_model = model.module if isinstance(model, DDP) else model

            parent_dir = os.path.dirname(checkpoint_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": bare_model.state_dict(),
                    "optimizer_state_dict": (
                        optimizer.state_dict() if optimizer is not None else None
                    ),
                    "scheduler_state_dict": (
                        scheduler.state_dict() if scheduler is not None else None
                    ),
                    "loss": loss,
                },
                checkpoint_path,
            )
            logger.info(f"Saved checkpoint to {checkpoint_path}")
    finally:
        # Synchronize all processes, also on failure, to avoid a hang.
        if dist.is_initialized():
            dist.barrier()
|
|
|
|
|
|
|
|
def load_checkpoint_ddp(
    model: torch.nn.Module,
    checkpoint_path: str,
    device: str = "cuda",
) -> dict:
    """
    Restore model weights from a checkpoint file.

    Only the "model_state_dict" entry is applied to the model; a
    DDP-wrapped model is unwrapped first so the weights are loaded into
    the underlying module. Any other entries are left to the caller via
    the returned dict.

    Args:
        model: Model to load into
        checkpoint_path: Path to checkpoint
        device: Device to load on (passed to torch.load as map_location)

    Returns:
        Checkpoint dict
    """
    # NOTE(review): torch.load deserializes with pickle by default on older
    # PyTorch versions -- only load checkpoints from trusted sources.
    loaded = torch.load(checkpoint_path, map_location=device)

    # Load model state into the underlying module when wrapped by DDP.
    target = model.module if isinstance(model, DDP) else model
    target.load_state_dict(loaded["model_state_dict"])

    logger.info(f"Loaded checkpoint from {checkpoint_path}")
    return loaded
|
|
|
|
|
|
|
|
def run_distributed_training(
    rank: int,
    world_size: int,
    train_fn,
    *args,
    **kwargs,
):
    """
    Per-process entry point: set up DDP, run the training function, clean up.

    The process group is initialized through setup_ddp with its default
    backend. cleanup_ddp always runs afterwards, including when setup or
    training raises.

    Args:
        rank: Process rank
        world_size: Total number of processes
        train_fn: Training function to run; called as
            train_fn(rank, world_size, *args, **kwargs)
        *args, **kwargs: Arguments to pass to train_fn
    """
    positional = (rank, world_size, *args)

    try:
        setup_ddp(rank=rank, world_size=world_size)
        train_fn(*positional, **kwargs)
    finally:
        # Runs on success and on failure alike.
        cleanup_ddp()
|
|
|
|
|
|
|
|
def launch_distributed_training(
    world_size: int,
    train_fn,
    *args,
    **kwargs,
):
    """
    Launch distributed training using torch.multiprocessing.

    Spawns ``world_size`` processes, each running
    run_distributed_training(rank, world_size, train_fn, *args), and waits
    for all of them to finish.

    torch.multiprocessing.spawn forwards positional arguments only, so any
    keyword arguments are bound to ``train_fn`` with functools.partial
    before spawning; they therefore reach ``train_fn`` instead of being
    dropped. ``train_fn`` and all arguments must be picklable, since they
    are sent to the spawned processes.

    Args:
        world_size: Number of GPUs to use
        train_fn: Training function (should accept rank and world_size as first args)
        *args, **kwargs: Additional arguments for train_fn
    """
    import functools

    import torch.multiprocessing as mp

    # Leave train_fn untouched when there is nothing to bind.
    target_fn = functools.partial(train_fn, **kwargs) if kwargs else train_fn

    mp.spawn(
        run_distributed_training,
        args=(world_size, target_fn) + args,
        nprocs=world_size,
        join=True,
    )
|
|
|