|
|
from __future__ import annotations |
|
|
|
|
|
import os |
|
|
import typing |
|
|
|
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
import torch.utils.data |
|
|
from torch.nn.parallel import DistributedDataParallel |
|
|
|
|
|
|
|
|
class Accelerator:
    """
    Simplified accelerator that mirrors the behaviour of the minicpm-audio
    training utilities. It initializes a distributed process group when
    ``torchrun`` is used and exposes helpers for AMP, gradient scaling and
    preparing models/dataloaders for DDP.

    Also usable as a context manager: inside the ``with`` block the process's
    ``LOCAL_RANK`` GPU is made the current CUDA device (no-op without CUDA).
    """

    def __init__(self, amp: bool = False):
        """Read the torchrun environment and set up the scaler/device context.

        Args:
            amp: Enable automatic mixed precision. When False, a no-op
                scaler stands in for ``torch.amp.GradScaler``.
        """
        # torchrun exports WORLD_SIZE; absent it, run single-process.
        self.world_size = int(os.getenv("WORLD_SIZE", "1"))

        if self.world_size > 1 and not dist.is_initialized():
            # NCCL backend: multi-process training here assumes CUDA GPUs.
            dist.init_process_group("nccl", init_method="env://")

        self.rank = dist.get_rank() if dist.is_initialized() else 0
        self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        self.amp = amp

        class DummyScaler:
            """No-op drop-in for ``torch.amp.GradScaler`` when AMP is off."""

            def step(self, optimizer):
                optimizer.step()

            def scale(self, loss):
                return loss

            def unscale_(self, optimizer):
                return optimizer

            def update(self):
                pass

        # GradScaler warns and disables itself when CUDA is unavailable,
        # so this is safe on CPU-only hosts as well.
        self.scaler = torch.amp.GradScaler() if amp else DummyScaler()
        # Context manager that pins this process to its local GPU.
        self.device_ctx = (
            torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
        )

    def __enter__(self):
        if self.device_ctx is not None:
            self.device_ctx.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.device_ctx is not None:
            self.device_ctx.__exit__(exc_type, exc_value, traceback)

    def prepare_model(self, model: torch.nn.Module, **kwargs):
        """Move ``model`` to this process's device; wrap in DDP when distributed.

        In the distributed case BatchNorm layers are converted to
        SyncBatchNorm so running statistics stay synchronized across ranks.
        Extra ``kwargs`` are forwarded to ``DistributedDataParallel``.
        """
        model = model.to(self.device)
        if self.world_size > 1:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            # FIX: DDP requires device_ids=None for CPU modules; only pass
            # the local rank when CUDA is actually in use.
            device_ids = [self.local_rank] if torch.cuda.is_available() else None
            model = DistributedDataParallel(model, device_ids=device_ids, **kwargs)
        return model

    @property
    def device(self) -> torch.device:
        """Best available device: CUDA (at ``local_rank``) > MPS > CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda", self.local_rank)
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def autocast(self, *args, **kwargs):
        """Return an autocast context that is enabled iff ``amp`` was requested.

        FIX: uses ``torch.autocast`` with the actual device type instead of
        the deprecated, CUDA-only ``torch.cuda.amp.autocast``, so the CPU and
        MPS paths exposed by :attr:`device` work too. Positional ``args``
        still map to ``dtype``, as they did with the old API.
        """
        return torch.autocast(self.device.type, *args, enabled=self.amp, **kwargs)

    def backward(self, loss: torch.Tensor):
        """Backpropagate ``loss``, scaling it first when AMP is enabled."""
        self.scaler.scale(loss).backward()

    def step(self, optimizer: torch.optim.Optimizer):
        """Run an optimizer step through the (possibly no-op) scaler."""
        self.scaler.step(optimizer)

    def update(self):
        """Update the gradient scaler's scale factor (no-op without AMP)."""
        self.scaler.update()

    def prepare_dataloader(
        self,
        dataset: typing.Iterable,
        *,
        batch_size: int,
        num_workers: int = 0,
        shuffle: bool = True,
        collate_fn=None,
        drop_last: bool = False,
    ) -> torch.utils.data.DataLoader:
        """Build a DataLoader, attaching a DistributedSampler when distributed.

        When a sampler is attached, shuffling is delegated to the sampler and
        the DataLoader's own ``shuffle`` flag must be False.
        """
        if self.world_size > 1:
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle
            )
            shuffle = False
        else:
            sampler = None

        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle if sampler is None else False,
            sampler=sampler,
            num_workers=num_workers,
            collate_fn=collate_fn,
            drop_last=drop_last,
            # FIX: pinning host memory only makes sense (and on some builds
            # is only possible) when CUDA is available.
            pin_memory=torch.cuda.is_available(),
        )

    @staticmethod
    def unwrap(model: torch.nn.Module) -> torch.nn.Module:
        """Return the underlying module from a DDP-wrapped (or plain) model."""
        return model.module if hasattr(model, "module") else model
|
|
|
|
|
|