import math

import torch
from torch.optim import Optimizer
from torch.optim.adamw import adamw
from torch.optim.lr_scheduler import LambdaLR

# DeepSpeed is an optional dependency; its fused / CPU Adam variants are only
# available when it is installed.
try:
    import deepspeed
    from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
except ImportError:
    pass


def get_optimizer(cfg, params):
    if cfg.optim.type == 'adam':
        return torch.optim.Adam(
            params=params,
            lr=cfg.optim.lr,
            weight_decay=cfg.optim.weight_decay,
            betas=(cfg.optim.beta1, cfg.optim.beta2),
        )
    elif cfg.optim.type == 'adamw':
        return AdamW(
            params=params,
            lr=cfg.optim.lr,
            weight_decay=cfg.optim.weight_decay,
            betas=(cfg.optim.beta1, cfg.optim.beta2),
        )
    elif cfg.optim.type == 'fusedadam':
        # Requires DeepSpeed (see the optional import above).
        return FusedAdam(
            params=params,
            lr=cfg.optim.lr,
            weight_decay=cfg.optim.weight_decay,
            betas=(cfg.optim.beta1, cfg.optim.beta2),
        )
    else:
        raise NotImplementedError('Optimizer not supported: %s' % cfg.optim.type)
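
# Minimal usage sketch (hedged): assumes an OmegaConf-style attribute config;
# the values below are illustrative, not defaults of this module.
#
#   cfg = OmegaConf.create({'optim': {'type': 'adamw', 'lr': 1e-4,
#                                     'weight_decay': 0.01,
#                                     'beta1': 0.9, 'beta2': 0.999}})
#   optimizer = get_optimizer(cfg, model.parameters())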


class AdamW(torch.optim.AdamW):
    """``torch.optim.AdamW`` variant that keeps the per-parameter ``step``
    counters on the CPU (outside capturable mode) before dispatching to the
    functional ``adamw``, e.g. after optimizer state has been restored onto
    the GPU from a checkpoint."""

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')
                grads.append(p.grad)

                state = self.state[p]

                # Lazy state initialization.
                if len(state) == 0:
                    state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
                        if self.defaults['capturable'] else torch.tensor(0.)
                    # Exponential moving average of gradient values.
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values.
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains the max of all exp. moving avg. of sq. grad. values.
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])

                if amsgrad:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                # Outside capturable mode, keep the step counter on the CPU
                # (it may have been moved to the GPU by a checkpoint load).
                # Re-assigning into the state dict, rather than appending a
                # `.cpu()` copy, ensures the in-place increment performed by
                # adamw() persists across calls.
                if not group['capturable']:
                    state['step'] = state['step'].cpu()
                state_steps.append(state['step'])

            adamw(params_with_grad,
                  grads,
                  exp_avgs,
                  exp_avg_sqs,
                  max_exp_avg_sqs,
                  state_steps,
                  amsgrad=amsgrad,
                  beta1=beta1,
                  beta2=beta2,
                  lr=group['lr'],
                  weight_decay=group['weight_decay'],
                  eps=group['eps'],
                  maximize=group['maximize'],
                  foreach=group['foreach'],
                  capturable=group['capturable'])

        return loss


def get_scheduler(cfg, optimizer):
    if cfg.optim.scheduler is None:
        return BlackHole()
    elif cfg.optim.scheduler == 'plateau':
        return (
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode=cfg.optim.mode,
                factor=cfg.optim.factor,
                patience=cfg.optim.patience,
                min_lr=cfg.optim.min_lr,
            ),
            {'monitor': 'val/loss', 'interval': 'epoch'}
        )
    elif cfg.optim.scheduler == 'noam':
        return (
            NoamScheduler(
                optimizer,
                lr=cfg.optim.lr,
                warmup_steps=cfg.optim.warmup_steps,
                model_size=cfg.optim.model_size,
                warmup_init_lr=cfg.optim.get('warmup_init_lr')
            ),
            {'frequency': 1, 'interval': 'step'}
        )
    elif cfg.optim.scheduler == 'polynomial':
        return (
            PolyNomialLRScheduler(
                optimizer,
                total_steps=cfg.training.max_steps,
                warmup_steps=cfg.training.warmup_steps,
                lr=cfg.optim.lr,
                lr_end=cfg.optim.lr_end,
                warmup_init_lr=cfg.optim.warmup_init_lr,
                power=cfg.optim.power
            ),
            {'frequency': 1, 'interval': 'step'}
        )
    elif cfg.optim.scheduler == 'multistep':
        return torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=cfg.optim.milestones,
            gamma=cfg.optim.gamma,
        )
    elif cfg.optim.scheduler == 'exp':
        return torch.optim.lr_scheduler.ExponentialLR(
            optimizer,
            gamma=cfg.optim.gamma,
        )
    elif cfg.optim.scheduler == 'progen_ft':
        sched = CosineToFrac(
            optimizer=optimizer,
            total_steps=cfg.training.max_steps,
            final_frac=0.2,
        )
        return (sched, {'frequency': 1, 'interval': 'step'})
    else:
        raise NotImplementedError('Scheduler not supported: %s' % cfg.optim.scheduler)
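
# Usage note (hedged): some branches above return a bare scheduler while
# others return a PyTorch-Lightning-style (scheduler, options) pair, so the
# caller needs to handle both shapes:
#
#   sched = get_scheduler(cfg, optimizer)
#   if isinstance(sched, tuple):
#       sched, sched_opts = sched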


class BlackHole(object):
    """No-op stand-in returned when no scheduler is configured: it silently
    discards attribute assignments and returns itself for any attribute
    access or call, so e.g. ``BlackHole().step()`` is a harmless no-op."""

    def __setattr__(self, name, value):
        pass

    def __call__(self, *args, **kwargs):
        return self

    def __getattr__(self, name):
        return self


def polynomial_lr_schedule(step, total_steps, warmup_steps, warmup_init_lr, lr, lr_end, power):
    if step < warmup_steps:
        # Linear warmup from warmup_init_lr to lr.
        return warmup_init_lr + (lr - warmup_init_lr) * step / warmup_steps
    elif step > total_steps:
        return lr_end
    else:
        # Polynomial decay from lr down to lr_end over the remaining steps.
        return lr_end + (lr - lr_end) * (1 - (step - warmup_steps) / (total_steps - warmup_steps)) ** power
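
# Worked example: with total_steps=1000, warmup_steps=100, warmup_init_lr=1e-7,
# lr=4e-5, lr_end=1e-5 and power=1.0, the LR ramps linearly to 4e-5 at step 100,
# then decays linearly, reaching 1e-5 at step 1000 and staying there.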


class PolyNomialLRScheduler(LambdaLR):
    def __init__(
        self,
        optimizer: Optimizer,
        total_steps: int = 1000,
        warmup_steps: int = 0,
        lr: float = 4e-05,
        lr_end: float = 1e-05,
        warmup_init_lr: float = 1e-07,
        power: float = 1.0,
    ) -> None:
        self.warmup_init_lr = warmup_init_lr
        self.warmup_steps = warmup_steps

        # LambdaLR multiplies the optimizer's base LR by lr_lambda(step), so
        # the schedule is divided by `lr`; the optimizer should therefore be
        # constructed with its learning rate set to `lr`.
        def lr_lambda(step):
            return polynomial_lr_schedule(
                step, total_steps, warmup_steps, warmup_init_lr, lr, lr_end, power
            ) / lr

        super().__init__(optimizer, lr_lambda=lr_lambda)


def cosine_frac_scheduler(step, total_steps, final_frac):
    # Cosine curve from 1.0 (step 0) down to final_frac (step >= total_steps).
    s = min(max(step, 0), total_steps)
    cos = 0.5 * (1.0 + math.cos(math.pi * s / total_steps))
    return final_frac + (1.0 - final_frac) * cos
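
# Worked example: with final_frac=0.2 the multiplier is 1.0 at step 0,
# 0.2 + 0.8 * 0.5 = 0.6 at the halfway point, and 0.2 from total_steps onward.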


class CosineToFrac(LambdaLR):
    """
    Cosine decay of the LR multiplier from 1.0 -> final_frac over total_steps (no warmup).
    For ProGen fine-tuning, final_frac=0.2 implements decay to lr/5.
    """
    def __init__(self, optimizer, total_steps, final_frac=0.2):
        self.total_steps = max(int(total_steps), 1)
        self.final_frac = float(final_frac)

        def lr_lambda(step):
            return cosine_frac_scheduler(
                step=step,
                total_steps=self.total_steps,
                final_frac=self.final_frac
            )

        super().__init__(optimizer, lr_lambda=lr_lambda)


def inverse_sqrt_lr_schedule(step, warmup_steps, warmup_init_lr, lr_step, decay_step):
    if step == 0:
        step = 1
    if step < warmup_steps:
        # Linear warmup.
        return warmup_init_lr + lr_step * step
    else:
        # Inverse-square-root decay, continuous with the warmup: at
        # step == warmup_steps, decay_step * step ** -0.5 == lr.
        return decay_step * step ** -0.5
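
# Worked example: with lr=5e-4 and warmup_steps=4000, decay_step = lr * sqrt(4000)
# ≈ 0.0316, so the LR falls back to lr/2 at step 16000 (sqrt(16000) = 2 * sqrt(4000)).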


class InverseSqrtLRScheduler(LambdaLR):
    def __init__(
        self,
        optimizer: Optimizer,
        warmup_steps: int = 0,
        lr: float = 5e-04,
        warmup_init_lr: float = 1e-07,
    ) -> None:
        warmup_steps = max(warmup_steps, 1)  # guard against division by zero below
        self.warmup_init_lr = warmup_init_lr
        self.warmup_steps = warmup_steps
        self.lr_step = (lr - warmup_init_lr) / warmup_steps
        self.decay_step = lr * warmup_steps ** 0.5

        def lr_lambda(step):
            return inverse_sqrt_lr_schedule(
                step, warmup_steps, warmup_init_lr, self.lr_step, self.decay_step
            ) / lr

        super().__init__(optimizer, lr_lambda=lr_lambda)


def noam_lr_schedule(step, warmup_steps, factor, model_size):
    # "Noam" schedule from "Attention Is All You Need": linear warmup followed
    # by inverse-square-root decay, scaled by model_size ** -0.5.
    if step == 0:
        step = 1
    return factor * (model_size ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5)))
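
# Sanity check: the two arguments of min() intersect at step == warmup_steps,
# where the LR peaks at factor * (model_size * warmup_steps) ** -0.5.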


class NoamScheduler(LambdaLR):
    def __init__(
        self,
        optimizer: Optimizer,
        lr: float,
        warmup_init_lr=None,  # accepted for call-site symmetry; unused by the Noam schedule
        model_size: int = 128,
        warmup_steps: int = 0,
        factor: float = 2.0,
    ) -> None:
        warmup_steps = max(warmup_steps, 1)  # avoid a zero division in the schedule

        def lr_lambda(step):
            # LambdaLR multiplies the optimizer's base LR (which should be set
            # to `lr`) by this factor, hence the division by lr.
            return noam_lr_schedule(step, warmup_steps, factor, model_size) / lr

        super().__init__(optimizer, lr_lambda=lr_lambda)
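

# Minimal smoke test (hedged: a toy parameter and illustrative values only).
if __name__ == '__main__':
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.AdamW([param], lr=5e-4)
    scheduler = NoamScheduler(optimizer, lr=5e-4, warmup_steps=4000)
    for _ in range(3):
        optimizer.step()
        scheduler.step()
    print(scheduler.get_last_lr())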