File size: 2,436 Bytes
b5a0bec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#
from typing import Tuple
from fairseq2.logging import get_log_writer
from fairseq2.optim.lr_scheduler import (
AbstractLRScheduler,
CosineAnnealingLR,
MyleLR,
NoopLR,
PolynomialDecayLR,
TriStageLR,
)
from torch.optim import Optimizer
logger = get_log_writer(__name__)
def build_lr_scheduler(
optimizer: Optimizer,
lr: float,
warmup_steps: int,
start_lr: float = 1e-7,
final_lr: float = 1e-5,
max_steps: int = 10_000,
stage_ratio: Tuple[float, ...] = (0.1, 0.4, 0.5),
schedule: str = "myle",
) -> AbstractLRScheduler:
assert schedule in [
"noop",
"myle",
"cosine",
"wsd",
"polynomial",
], (
f"Cannot recognize the learing rate schedule {schedule}, only noop, myle, cosine and wsd are supported"
)
assert lr > 0, "The learning reate should be strictly positive"
lr_scheduler: AbstractLRScheduler
if schedule == "noop":
lr_scheduler = NoopLR(optimizer)
elif schedule == "myle":
lr_scheduler = MyleLR(
optimizer,
num_warmup_steps=warmup_steps,
start_lr=[start_lr],
)
elif schedule == "cosine":
lr_scheduler = CosineAnnealingLR(
optimizer,
cycle_len=max_steps - warmup_steps + 1,
num_warmup_steps=warmup_steps,
start_lr=[start_lr],
final_lr=[final_lr],
cycle_mul=1.0,
lr_mul=1.0,
)
elif schedule == "wsd":
assert lr > start_lr, (
f"the starting learning rate {start_lr} should be lesser than the main lr {lr}"
)
start_lr_scale = start_lr / lr
assert lr > final_lr, (
f"the final learning rate {final_lr} should be lesser than the main lr {lr}"
)
final_lr_scale = final_lr / lr
lr_scheduler = TriStageLR(
optimizer,
max_steps,
stage_ratio=stage_ratio, # type: ignore
start_lr_scale=start_lr_scale,
final_lr_scale=final_lr_scale,
)
elif schedule == "polynomial":
lr_scheduler = PolynomialDecayLR(
optimizer,
max_steps,
warmup_steps,
power=200,
start_lr=start_lr,
final_lr=final_lr,
)
return lr_scheduler
|