| import torch
|
| import numpy as np
|
| import logging
|
| from math import inf
|
| import math
|
|
|
| class NativeScalerWithGradNormCount:
|
| state_dict_key = "amp_scaler"
|
|
|
| def __init__(self):
|
| self._scaler = torch.amp.GradScaler('cuda')
|
|
|
| def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
|
| self._scaler.scale(loss).backward(create_graph=create_graph)
|
| if update_grad:
|
| if clip_grad is not None:
|
| assert parameters is not None
|
| self._scaler.unscale_(optimizer)
|
| norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
|
| else:
|
| self._scaler.unscale_(optimizer)
|
| norm = get_grad_norm_(parameters)
|
| self._scaler.step(optimizer)
|
| self._scaler.update()
|
| else:
|
| norm = None
|
| return norm
|
|
|
| def state_dict(self):
|
| return self._scaler.state_dict()
|
|
|
| def load_state_dict(self, state_dict):
|
| self._scaler.load_state_dict(state_dict)
|
|
|
|
|
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Compute the overall norm of the gradients held by ``parameters``.

    Args:
        parameters: A single tensor or an iterable of tensors; entries whose
            ``.grad`` is ``None`` are ignored.
        norm_type: Order of the norm. ``inf`` selects the largest absolute
            gradient entry; any other value takes the ``norm_type``-norm of
            the per-tensor ``norm_type``-norms.

    Returns:
        A scalar tensor on the device of the first available gradient, or a
        CPU ``tensor(0.)`` when no tensor has a gradient.
    """
    candidates = [parameters] if isinstance(parameters, torch.Tensor) else parameters
    grads = [t.grad for t in candidates if t.grad is not None]
    norm_type = float(norm_type)

    if not grads:
        return torch.tensor(0.)

    # Gather every partial result on one device before combining them.
    target = grads[0].device

    if norm_type == inf:
        return max(g.detach().abs().max().to(target) for g in grads)

    partial_norms = [torch.norm(g.detach(), norm_type).to(target) for g in grads]
    return torch.norm(torch.stack(partial_norms), norm_type)
|
|
|
|
|
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
                     start_warmup_value=0, warmup_steps=-1):
    """Build a per-iteration schedule: linear warmup followed by cosine decay.

    Args:
        base_value: Value reached at the end of warmup and at the start of the
            cosine segment.
        final_value: Value the cosine segment decays towards.
        epochs: Number of epochs covered by the schedule.
        niter_per_ep: Iterations per epoch.
        warmup_epochs: Warmup length in epochs (used when ``warmup_steps`` is
            not positive).
        start_warmup_value: First value of the linear warmup.
        warmup_steps: Warmup length in iterations; when positive it overrides
            ``warmup_epochs``.

    Returns:
        A 1-D NumPy array of length ``epochs * niter_per_ep``.

    Raises:
        AssertionError: If the resulting schedule does not have
            ``epochs * niter_per_ep`` entries (e.g. warmup longer than the
            whole run).
    """
    # int() keeps fractional warmup_epochs usable: np.linspace needs an
    # integer sample count.
    warmup_iters = int(warmup_steps if warmup_steps > 0 else warmup_epochs * niter_per_ep)
    logging.info("Set warmup steps = %d", warmup_iters)

    # Gate on the resolved iteration count, not on warmup_epochs: the cosine
    # segment below is always shortened by warmup_iters, so a warmup given
    # only via warmup_steps must still produce its ramp.
    if warmup_iters > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
    else:
        warmup_schedule = np.array([])

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    # Half-cosine from base_value (i = 0) towards final_value.
    schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))

    assert len(schedule) == epochs * niter_per_ep, "schedule length does not match epochs * niter_per_ep"
    return schedule