# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#

from typing import Tuple

from fairseq2.logging import get_log_writer
from fairseq2.optim.lr_scheduler import (
    AbstractLRScheduler,
    CosineAnnealingLR,
    MyleLR,
    NoopLR,
    PolynomialDecayLR,
    TriStageLR,
)
from torch.optim import Optimizer

logger = get_log_writer(__name__)

def build_lr_scheduler(
    optimizer: Optimizer,
    lr: float,
    warmup_steps: int,
    start_lr: float = 1e-7,
    final_lr: float = 1e-5,
    max_steps: int = 10_000,
    stage_ratio: Tuple[float, ...] = (0.1, 0.4, 0.5),
    schedule: str = "myle",
) -> AbstractLRScheduler:
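    """Build a fairseq2 LR scheduler for the given optimizer.

    ``schedule`` selects one of: noop (constant lr), myle (linear warmup
    then inverse-sqrt decay), cosine (a single cosine cycle), wsd
    (warmup-stable-decay, mapped onto TriStageLR) or polynomial decay.
    ``stage_ratio`` splits ``max_steps`` between the warmup, stable and
    decay phases of the wsd schedule.
    """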
    assert schedule in [
        "noop",
        "myle",
        "cosine",
        "wsd",
        "polynomial",
    ], (
        f"Cannot recognize the learning rate schedule {schedule}; "
        "only noop, myle, cosine, wsd and polynomial are supported"
    )
    assert lr > 0, "The learning rate should be strictly positive"
    lr_scheduler: AbstractLRScheduler
    if schedule == "noop":
        # Constant learning rate, no warmup or decay.
        lr_scheduler = NoopLR(optimizer)
    elif schedule == "myle":
        # Linear warmup followed by inverse-sqrt decay.
        lr_scheduler = MyleLR(
            optimizer,
            num_warmup_steps=warmup_steps,
            start_lr=[start_lr],
        )
    elif schedule == "cosine":
        # A single cosine cycle spanning all post-warmup steps.
        lr_scheduler = CosineAnnealingLR(
            optimizer,
            cycle_len=max_steps - warmup_steps + 1,
            num_warmup_steps=warmup_steps,
            start_lr=[start_lr],
            final_lr=[final_lr],
            cycle_mul=1.0,
            lr_mul=1.0,
        )
    elif schedule == "wsd":
        # Warmup-stable-decay, implemented with TriStageLR, which
        # expresses the start and final lr as scales of the main lr.
        assert lr > start_lr, (
            f"The starting learning rate {start_lr} should be less than the main lr {lr}"
        )
        start_lr_scale = start_lr / lr
        assert lr > final_lr, (
            f"The final learning rate {final_lr} should be less than the main lr {lr}"
        )
        final_lr_scale = final_lr / lr
        lr_scheduler = TriStageLR(
            optimizer,
            max_steps,
            stage_ratio=stage_ratio,  # type: ignore
            start_lr_scale=start_lr_scale,
            final_lr_scale=final_lr_scale,
        )
    elif schedule == "polynomial":
        lr_scheduler = PolynomialDecayLR(
            optimizer,
            max_steps,
            warmup_steps,
            # A very high power makes the lr fall off sharply toward
            # final_lr soon after warmup ends.
            power=200,
            start_lr=start_lr,
            final_lr=final_lr,
        )
    return lr_scheduler
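

# Minimal usage sketch (an illustration, not part of the original file):
# builds a cosine scheduler over a throwaway parameter and steps it a few
# times to show the warmup ramp. Assumes only torch and fairseq2 are
# available; the toy parameter and the step counts are arbitrary.
if __name__ == "__main__":
    import torch

    # A single dummy parameter is enough to drive the optimizer.
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.AdamW([param], lr=1e-4)

    lr_scheduler = build_lr_scheduler(
        optimizer,
        lr=1e-4,
        warmup_steps=100,
        max_steps=1_000,
        schedule="cosine",
    )

    for step in range(5):
        optimizer.step()
        lr_scheduler.step()
        logger.info(f"step={step} lr={lr_scheduler.get_last_lr()[0]:.3e}")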