import os
import shutil
import random

import numpy as np
import torch
import torch.distributed as dist
from torch import inf


def setup_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def save_checkpoint(args, state, is_best):
    """Save the latest checkpoint and, if it is the best so far, copy it to a per-run best file."""
    os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=True)
    filename = os.path.join(args.output_dir, 'checkpoints', 'checkpoint.pth.tar')
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(
            filename,
            os.path.join(args.output_dir, 'checkpoints', '{}_best.pth.tar'.format(args.run_name)),
        )


def get_grad_norm(parameters, norm_type=2):
    """Return the total gradient norm (as a float) over parameters that have gradients."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm


def reduce_tensor(tensor):
    """Average a tensor across all processes in the distributed group."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt


def ampscaler_get_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Gradient-norm computation used by the AMP loss scaler; returns a tensor."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.0)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(
            torch.stack(
                [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]
            ),
            norm_type,
        )
    return total_norm


class NativeScalerWithGradNormCount:
    """Wraps torch.cuda.amp.GradScaler and reports the gradient norm of each update."""

    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(
        self,
        loss,
        optimizer,
        clip_grad=None,
        parameters=None,
        create_graph=False,
        update_grad=True,
    ):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                # Unscale the gradients of the optimizer's assigned params in place.
                self._scaler.unscale_(optimizer)
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = ampscaler_get_grad_norm(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)


def auto_resume_helper(output_dir):
    """Return the most recently modified checkpoint ending in 'pth' in output_dir, or None."""
    checkpoints = os.listdir(output_dir)
    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith("pth")]
    print(f"All checkpoints found in {output_dir}: {checkpoints}")
    if len(checkpoints) > 0:
        latest_checkpoint = max(
            [os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime
        )
        print(f"The latest checkpoint found: {latest_checkpoint}")
        resume_file = latest_checkpoint
    else:
        resume_file = None
    return resume_file


if __name__ == '__main__':
    pass
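
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one way NativeScalerWithGradNormCount can
# be driven inside an AMP training step. The names `model`, `optimizer`,
# `criterion`, `images`, and `targets` are hypothetical and not defined in
# this module.
#
#   loss_scaler = NativeScalerWithGradNormCount()
#   with torch.cuda.amp.autocast():
#       loss = criterion(model(images), targets)
#   optimizer.zero_grad()
#   grad_norm = loss_scaler(
#       loss,
#       optimizer,
#       clip_grad=1.0,                   # optional clipping threshold
#       parameters=model.parameters(),
#       update_grad=True,                # set False when accumulating gradients
#   )
# ---------------------------------------------------------------------------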