import os
import random
import shutil

import numpy as np
import torch
import torch.distributed as dist
from torch import inf


def setup_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def save_checkpoint(args, state, is_best):
    """Save `state` under <output_dir>/checkpoints and copy it to a best-model file when `is_best`."""
    os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=True)
    filename = os.path.join(args.output_dir, 'checkpoints', 'checkpoint.pth.tar')
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, os.path.join(args.output_dir, 'checkpoints',
                                               '{}_best.pth.tar'.format(args.run_name)))
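

# Usage sketch (illustrative only, not part of the original utilities; `args`, `epoch`,
# `model`, `optimizer`, `val_acc` and `best_acc` are hypothetical caller-side names):
#
#     state = {'epoch': epoch, 'state_dict': model.state_dict(),
#              'optimizer': optimizer.state_dict()}
#     save_checkpoint(args, state, is_best=val_acc > best_acc)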


def get_grad_norm(parameters, norm_type=2):
    """Return the total gradient norm of `parameters` as a Python float."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    total_norm = 0.0
    for p in parameters:
        param_norm = p.grad.detach().norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm


def reduce_tensor(tensor):
    """Average `tensor` across all distributed processes (requires an initialized process group)."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt
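

# Usage sketch (assumes torch.distributed.init_process_group has already been called
# by the training launcher; `loss` is a hypothetical scalar tensor on the local device):
#
#     loss_avg = reduce_tensor(loss.detach())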


def ampscaler_get_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor:
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.0)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(
            torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
            norm_type,
        )
    return total_norm


class NativeScalerWithGradNormCount:
    """Thin wrapper around torch.cuda.amp.GradScaler that also reports the gradient norm."""

    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None,
                 create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                # Unscale gradients in place before clipping so the norm is computed
                # on the true (unscaled) gradients.
                self._scaler.unscale_(optimizer)
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = ampscaler_get_grad_norm(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
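

# Illustrative sketch (not part of the original utilities): a typical AMP training
# step using NativeScalerWithGradNormCount. All names below (model, criterion,
# optimizer, images, targets, clip_grad) are hypothetical placeholders supplied
# by the caller.
def _example_amp_step(model, criterion, optimizer, loss_scaler, images, targets, clip_grad=None):
    optimizer.zero_grad()
    # Run the forward pass under autocast so activations use mixed precision.
    with torch.cuda.amp.autocast():
        outputs = model(images)
        loss = criterion(outputs, targets)
    # The scaler backpropagates the scaled loss, optionally clips gradients,
    # steps the optimizer, and returns the (unscaled) gradient norm.
    grad_norm = loss_scaler(loss, optimizer, clip_grad=clip_grad,
                            parameters=model.parameters(), update_grad=True)
    return loss.item(), grad_norm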


def auto_resume_helper(output_dir):
    """Return the most recently modified *.pth checkpoint in `output_dir`, or None if there is none."""
    checkpoints = os.listdir(output_dir)
    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith("pth")]
    print(f"All checkpoints found in {output_dir}: {checkpoints}")
    if len(checkpoints) > 0:
        latest_checkpoint = max(
            [os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime
        )
        print(f"The latest checkpoint found: {latest_checkpoint}")
        resume_file = latest_checkpoint
    else:
        resume_file = None
    return resume_file
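

# Usage sketch (hypothetical `args.output_dir`; not part of the original utilities):
#
#     resume_file = auto_resume_helper(args.output_dir)
#     if resume_file is not None:
#         checkpoint = torch.load(resume_file, map_location='cpu')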


if __name__ == '__main__':
    pass