lglg666's picture
Upload folder using huggingface_hub
6766eda verified
from __future__ import annotations
import os
import typing
import torch
import torch.distributed as dist
import torch.utils.data
from torch.nn.parallel import DistributedDataParallel
class Accelerator:
"""
Simplified accelerator that mirrors the behaviour of the minicpm-audio
training utilities. It initializes a distributed process group when
``torchrun`` is used and exposes helpers for AMP, gradient scaling and
preparing models/dataloaders for DDP.
"""
def __init__(self, amp: bool = False):
self.world_size = int(os.getenv("WORLD_SIZE", "1"))
if self.world_size > 1 and not dist.is_initialized():
dist.init_process_group("nccl", init_method="env://")
self.rank = dist.get_rank() if dist.is_initialized() else 0
self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
self.amp = amp
class DummyScaler:
def step(self, optimizer):
optimizer.step()
def scale(self, loss):
return loss
def unscale_(self, optimizer):
return optimizer
def update(self):
pass
self.scaler = torch.amp.GradScaler() if amp else DummyScaler()
self.device_ctx = (
torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
)
def __enter__(self):
if self.device_ctx is not None:
self.device_ctx.__enter__()
return self
def __exit__(self, exc_type, exc_value, traceback):
if self.device_ctx is not None:
self.device_ctx.__exit__(exc_type, exc_value, traceback)
# ------------------------------------------------------------------ #
# Model helpers
# ------------------------------------------------------------------ #
def prepare_model(self, model: torch.nn.Module, **kwargs):
model = model.to(self.device)
if self.world_size > 1:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DistributedDataParallel(model, device_ids=[self.local_rank], **kwargs)
return model
@property
def device(self):
if torch.cuda.is_available():
return torch.device("cuda", self.local_rank)
if torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")
# ------------------------------------------------------------------ #
# AMP helpers
# ------------------------------------------------------------------ #
def autocast(self, *args, **kwargs):
return torch.cuda.amp.autocast(self.amp, *args, **kwargs)
def backward(self, loss: torch.Tensor):
self.scaler.scale(loss).backward()
def step(self, optimizer: torch.optim.Optimizer):
self.scaler.step(optimizer)
def update(self):
self.scaler.update()
# ------------------------------------------------------------------ #
# Data helpers
# ------------------------------------------------------------------ #
def prepare_dataloader(
self,
dataset: typing.Iterable,
*,
batch_size: int,
num_workers: int = 0,
shuffle: bool = True,
collate_fn=None,
drop_last: bool = False,
) -> torch.utils.data.DataLoader:
if self.world_size > 1:
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle
)
shuffle = False
else:
sampler = None
return torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle if sampler is None else False,
sampler=sampler,
num_workers=num_workers,
collate_fn=collate_fn,
drop_last=drop_last,
pin_memory=True,
)
@staticmethod
def unwrap(model: torch.nn.Module) -> torch.nn.Module:
return model.module if hasattr(model, "module") else model