pvnet_nl / pvnet / optimizers.py
"""Optimizer factory-function classes.
"""
from abc import ABC, abstractmethod

import torch


class AbstractOptimizer(ABC):
    """Abstract class for optimizers.

    Optimizer classes are used like:

    > optimizer_generator = Adam()
    > optimizer = optimizer_generator(model)

    The returned object `optimizer` must be something that may be returned by
    `pytorch_lightning`'s `configure_optimizers()` method.

    See:
    https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers
    """

    @abstractmethod
    def __call__(self, model):
        """Abstract call"""
        pass
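
# Note: the concrete subclasses below return either a bare torch optimizer
# (Adam, AdamW) or an `([optimizers], [schedulers])` pair (the
# ReduceLROnPlateau variants); both are valid `configure_optimizers()`
# return values.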


class Adam(AbstractOptimizer):
    """Adam optimizer"""

    def __init__(self, lr=0.0005, **kwargs):
        """Adam optimizer"""
        self.lr = lr
        self.kwargs = kwargs

    def __call__(self, model):
        """Return optimizer"""
        return torch.optim.Adam(model.parameters(), lr=self.lr, **self.kwargs)


class AdamW(AbstractOptimizer):
    """AdamW optimizer"""

    def __init__(self, lr=0.0005, **kwargs):
        """AdamW optimizer"""
        self.lr = lr
        self.kwargs = kwargs

    def __call__(self, model):
        """Return optimizer"""
        return torch.optim.AdamW(model.parameters(), lr=self.lr, **self.kwargs)
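
# A minimal sketch of how these factory objects are typically wired into a
# LightningModule's `configure_optimizers()`. `MyModel` and its
# `self._optimizer` attribute are hypothetical names used only for
# illustration (`pl` stands for `pytorch_lightning`).
#
#     class MyModel(pl.LightningModule):
#         def __init__(self, optimizer=AdamW(lr=1e-4, weight_decay=0.01)):
#             super().__init__()
#             self._optimizer = optimizer
#
#         def configure_optimizers(self):
#             return self._optimizer(self)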


def find_submodule_parameters(model, search_modules):
    """Finds all parameters within the given submodule types

    Args:
        model: torch Module to search through
        search_modules: Tuple of submodule types to search for
    """
    if isinstance(model, search_modules):
        return model.parameters()

    children = list(model.children())
    if len(children) == 0:
        return []
    else:
        params = []
        for c in children:
            params += find_submodule_parameters(c, search_modules)
        return params


def find_other_than_submodule_parameters(model, ignore_modules):
    """Finds all parameters outside of the given submodule types

    Args:
        model: torch Module to search through
        ignore_modules: Tuple of submodule types to ignore
    """
    if isinstance(model, ignore_modules):
        return []

    children = list(model.children())
    if len(children) == 0:
        return model.parameters()
    else:
        params = []
        for c in children:
            params += find_other_than_submodule_parameters(c, ignore_modules)
        return params
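
# Hedged example: the two helpers above partition a model's parameters so
# that, e.g., embedding weights can be excluded from weight decay. The toy
# model below is purely illustrative.
#
#     toy = torch.nn.Sequential(torch.nn.Embedding(10, 8), torch.nn.Linear(8, 1))
#     emb_params = find_submodule_parameters(toy, (torch.nn.Embedding,))
#     other_params = find_other_than_submodule_parameters(toy, (torch.nn.Embedding,))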


class EmbAdamWReduceLROnPlateau(AbstractOptimizer):
    """AdamW optimizer and ReduceLROnPlateau scheduler, excluding embeddings from weight decay"""

    def __init__(
        self, lr=0.0005, weight_decay=0.01, patience=3, factor=0.5, threshold=2e-4, **opt_kwargs
    ):
        """AdamW optimizer and ReduceLROnPlateau scheduler

        Args:
            lr: Learning rate
            weight_decay: Weight decay applied to all non-embedding parameters
            patience: Number of epochs with no improvement after which the learning rate is
                reduced
            factor: Factor by which the learning rate is reduced
            threshold: Threshold for measuring the new optimum
            **opt_kwargs: Other keyword arguments passed to AdamW
        """
        self.lr = lr
        self.weight_decay = weight_decay
        self.patience = patience
        self.factor = factor
        self.threshold = threshold
        self.opt_kwargs = opt_kwargs

    def __call__(self, model):
        """Return optimizer"""
        search_modules = (torch.nn.Embedding,)

        no_decay = find_submodule_parameters(model, search_modules)
        decay = find_other_than_submodule_parameters(model, search_modules)

        optim_groups = [
            {"params": decay, "weight_decay": self.weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ]
        opt = torch.optim.AdamW(optim_groups, lr=self.lr, **self.opt_kwargs)

        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt,
            factor=self.factor,
            patience=self.patience,
            threshold=self.threshold,
        )
        sch = {
            "scheduler": sch,
            "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
        }
        return [opt], [sch]
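
# Hedged usage sketch (the model passed in is assumed to be a LightningModule
# exposing a `use_quantile_regression` attribute and logging the monitored
# metric, as required by `__call__` above):
#
#     optimizer_factory = EmbAdamWReduceLROnPlateau(lr=1e-4, weight_decay=0.01)
#     # inside LightningModule.configure_optimizers():
#     #     return optimizer_factory(self)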


class AdamWReduceLROnPlateau(AbstractOptimizer):
    """AdamW optimizer and ReduceLROnPlateau scheduler"""

    def __init__(
        self, lr=0.0005, patience=3, factor=0.5, threshold=2e-4, step_freq=None, **opt_kwargs
    ):
        """AdamW optimizer and ReduceLROnPlateau scheduler

        Args:
            lr: Learning rate, either a float or a dict mapping submodule-name prefixes to
                learning rates, with a "default" key covering all remaining parameters
            patience: Number of epochs with no improvement after which the learning rate is
                reduced
            factor: Factor by which the learning rate is reduced
            threshold: Threshold for measuring the new optimum
            step_freq: Not used within this class
            **opt_kwargs: Other keyword arguments passed to AdamW
        """
        self._lr = lr
        self.patience = patience
        self.factor = factor
        self.threshold = threshold
        self.step_freq = step_freq
        self.opt_kwargs = opt_kwargs

    def _call_multi(self, model):
        """Return optimizer with per-submodule learning rates"""
        remaining_params = {k: p for k, p in model.named_parameters()}

        group_args = []

        for key in self._lr.keys():
            if key == "default":
                continue

            submodule_params = []
            for param_name in list(remaining_params.keys()):
                if param_name.startswith(key):
                    submodule_params += [remaining_params.pop(param_name)]

            group_args += [{"params": submodule_params, "lr": self._lr[key]}]

        remaining_params = [p for k, p in remaining_params.items()]
        group_args += [{"params": remaining_params}]

        opt = torch.optim.AdamW(
            group_args, lr=self._lr["default"] if model.lr is None else model.lr, **self.opt_kwargs
        )
        sch = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                opt,
                factor=self.factor,
                patience=self.patience,
                threshold=self.threshold,
            ),
            "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
        }

        return [opt], [sch]

    def __call__(self, model):
        """Return optimizer"""
        if not isinstance(self._lr, float):
            return self._call_multi(model)
        else:
            default_lr = self._lr if model.lr is None else model.lr
            opt = torch.optim.AdamW(model.parameters(), lr=default_lr, **self.opt_kwargs)

            sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
                opt,
                factor=self.factor,
                patience=self.patience,
                threshold=self.threshold,
            )
            sch = {
                "scheduler": sch,
                "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
            }

            return [opt], [sch]
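
# Hedged example of the per-submodule learning-rate form handled by
# `_call_multi()` above. Keys are prefixes of `model.named_parameters()`
# names; "default" applies to everything left over, unless the model's own
# `lr` attribute is not None, in which case that takes precedence. The
# submodule names below are hypothetical.
#
#     optimizer_factory = AdamWReduceLROnPlateau(
#         lr={"encoder": 1e-5, "output_network": 1e-4, "default": 3e-4},
#     )
#     # inside LightningModule.configure_optimizers():
#     #     return optimizer_factory(self)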