| | """ |
| | Optimizer and schedulers |
| | """ |
| | import torch |
| | import torch.nn as nn |
| | from torch.optim import Optimizer |
| | from torch.optim.lr_scheduler import LRScheduler |
| |
|
| |
|
def get_optimizer(optim: str, model: nn.Module, **kwargs: Any) -> Optimizer:
    """
    Return the training optimizer specified by `optim`.
    """
    if optim == 'sgd':
        return torch.optim.SGD(model.parameters(), **kwargs)
    elif optim == 'adam':
        return torch.optim.Adam(model.parameters(), **kwargs)
    elif optim in ['adamw', 'adamw_torch']:
        return torch.optim.AdamW(model.parameters(), **kwargs)
    elif optim == 'adamw_torch_fused':
        # Fused AdamW kernel; typically requires parameters on CUDA.
        return torch.optim.AdamW(model.parameters(), **kwargs, fused=True)
    elif optim == 'adafactor':
        from transformers import Adafactor
        # Disable Adafactor's time-dependent relative step sizes so the
        # learning rate supplied via kwargs is used instead.
        kwargs['relative_step'] = False
        return Adafactor(model.parameters(), **kwargs)
    else:
        raise NotImplementedError(f"{optim} optimizer not implemented.")

def get_scheduler(lr_scheduler_type: str, optimizer: Optimizer,
                  **kwargs: Any) -> Optional[LRScheduler]:
    """
    Return the learning rate scheduler specified by `lr_scheduler_type`,
    or None if the type is not recognized.
    """
    if lr_scheduler_type in ['plateau', 'reduce_lr_on_plateau']:
        from torch.optim.lr_scheduler import ReduceLROnPlateau
        return ReduceLROnPlateau(optimizer=optimizer, **kwargs)

    elif lr_scheduler_type == 'cosine_warmup':
        from transformers import get_cosine_schedule_with_warmup
        return get_cosine_schedule_with_warmup(optimizer=optimizer, **kwargs)

    elif lr_scheduler_type in ['linear_warmup', 'linear']:
        from transformers import get_linear_schedule_with_warmup
        return get_linear_schedule_with_warmup(optimizer=optimizer, **kwargs)

    else:
        return None
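

if __name__ == '__main__':
    # Minimal usage sketch, not part of the module's API: the model,
    # optimizer choice, and all hyperparameter values below are illustrative
    # assumptions. The 'cosine_warmup' branch additionally requires the
    # `transformers` package to be installed.
    model = nn.Linear(16, 4)
    optimizer = get_optimizer('adamw', model, lr=1e-3, weight_decay=0.01)
    scheduler = get_scheduler('cosine_warmup', optimizer,
                              num_warmup_steps=10, num_training_steps=100)

    # Stand-in for a training loop: step the optimizer, then the scheduler.
    for _ in range(3):
        optimizer.step()
        scheduler.step()
    print(f"lr after 3 warmup steps: {optimizer.param_groups[0]['lr']:.2e}")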