"""Optimizer factory-function classes.
"""
from abc import ABC, abstractmethod
import torch


class AbstractOptimizer(ABC):
    """Abstract base class for optimizer factories.

    Optimizer factory classes are used by models like:

    >>> optimizer_factory = Adam()
    >>> optimizer = optimizer_factory(model)

    The returned object `optimizer` must be something that may be returned by
    `pytorch_lightning`'s `configure_optimizers()` method.

    See:
    https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers
    """

    @abstractmethod
    def __call__(self, model):
        """Return an optimizer (and optionally a scheduler) for `model`."""


class Adam(AbstractOptimizer):
    """Adam optimizer factory."""

    def __init__(self, lr=0.0005, **kwargs):
        """Store the learning rate and any extra `torch.optim.Adam` keyword arguments."""
        self.lr = lr
        self.kwargs = kwargs

    def __call__(self, model):
        """Return an Adam optimizer over the model's parameters."""
        return torch.optim.Adam(model.parameters(), lr=self.lr, **self.kwargs)
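

# Example (a minimal sketch; `model` stands in for any torch.nn.Module):
# extra keyword arguments given to the factory are forwarded unchanged to
# `torch.optim.Adam`.
#
#     optimizer_factory = Adam(lr=1e-3, betas=(0.9, 0.99))
#     optimizer = optimizer_factory(model)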


class AdamW(AbstractOptimizer):
    """AdamW optimizer factory."""

    def __init__(self, lr=0.0005, **kwargs):
        """Store the learning rate and any extra `torch.optim.AdamW` keyword arguments."""
        self.lr = lr
        self.kwargs = kwargs

    def __call__(self, model):
        """Return an AdamW optimizer over the model's parameters."""
        return torch.optim.AdamW(model.parameters(), lr=self.lr, **self.kwargs)


def find_submodule_parameters(model, search_modules):
    """Find all parameters within the given submodule types.

    Args:
        model: torch Module to search through
        search_modules: Tuple of submodule types to search for
    """
    if isinstance(model, search_modules):
        return list(model.parameters())
    children = list(model.children())
    if len(children) == 0:
        return []
    params = []
    for c in children:
        params += find_submodule_parameters(c, search_modules)
    return params


def find_other_than_submodule_parameters(model, ignore_modules):
    """Find all parameters outside the given submodule types.

    Args:
        model: torch Module to search through
        ignore_modules: Tuple of submodule types to ignore
    """
    if isinstance(model, ignore_modules):
        return []
    children = list(model.children())
    if len(children) == 0:
        return list(model.parameters())
    params = []
    for c in children:
        params += find_other_than_submodule_parameters(c, ignore_modules)
    return params
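

# Example (illustrative sketch): together, the two helpers partition a model's
# parameters by submodule type, e.g. to exclude embeddings from weight decay.
#
#     model = torch.nn.Sequential(torch.nn.Embedding(10, 4), torch.nn.Linear(4, 1))
#     emb_params = find_submodule_parameters(model, (torch.nn.Embedding,))
#     other_params = find_other_than_submodule_parameters(model, (torch.nn.Embedding,))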


class EmbAdamWReduceLROnPlateau(AbstractOptimizer):
    """AdamW optimizer, with no weight decay on embeddings, plus a reduce-on-plateau scheduler."""

    def __init__(
        self, lr=0.0005, weight_decay=0.01, patience=3, factor=0.5, threshold=2e-4, **opt_kwargs
    ):
        """Store optimizer and scheduler settings; extra kwargs are passed to `torch.optim.AdamW`."""
        self.lr = lr
        self.weight_decay = weight_decay
        self.patience = patience
        self.factor = factor
        self.threshold = threshold
        self.opt_kwargs = opt_kwargs

    def __call__(self, model):
        """Return the optimizer and scheduler for `model`."""
        # Exclude embedding parameters from weight decay.
        search_modules = (torch.nn.Embedding,)
        no_decay = find_submodule_parameters(model, search_modules)
        decay = find_other_than_submodule_parameters(model, search_modules)
        optim_groups = [
            {"params": decay, "weight_decay": self.weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ]
        opt = torch.optim.AdamW(optim_groups, lr=self.lr, **self.opt_kwargs)
        sch = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                opt,
                factor=self.factor,
                patience=self.patience,
                threshold=self.threshold,
            ),
            "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
        }
        return [opt], [sch]
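

# Example (sketch, assuming a pytorch_lightning model that exposes a
# `use_quantile_regression` attribute, as `__call__` above requires):
#
#     class Model(pl.LightningModule):
#         ...
#         def configure_optimizers(self):
#             return EmbAdamWReduceLROnPlateau(lr=1e-4, weight_decay=0.01)(self)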


class AdamWReduceLROnPlateau(AbstractOptimizer):
    """AdamW optimizer with a reduce-on-plateau scheduler.

    `lr` may be a single float, or a dict mapping parameter-name prefixes to
    learning rates, with a "default" entry for all remaining parameters.
    """

    def __init__(
        self, lr=0.0005, patience=3, factor=0.5, threshold=2e-4, step_freq=None, **opt_kwargs
    ):
        """Store optimizer and scheduler settings; extra kwargs are passed to `torch.optim.AdamW`."""
        self._lr = lr
        self.patience = patience
        self.factor = factor
        self.threshold = threshold
        self.step_freq = step_freq
        self.opt_kwargs = opt_kwargs

    def _call_multi(self, model):
        """Build parameter groups from a dict of per-submodule learning rates."""
        remaining_params = dict(model.named_parameters())
        group_args = []
        for key in self._lr:
            if key == "default":
                continue
            # Collect all parameters whose names start with this prefix.
            submodule_params = []
            for param_name in list(remaining_params.keys()):
                if param_name.startswith(key):
                    submodule_params.append(remaining_params.pop(param_name))
            group_args.append({"params": submodule_params, "lr": self._lr[key]})
        # Everything left falls back to the default learning rate.
        group_args.append({"params": list(remaining_params.values())})
        opt = torch.optim.AdamW(
            group_args, lr=self._lr["default"] if model.lr is None else model.lr, **self.opt_kwargs
        )
        sch = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                opt,
                factor=self.factor,
                patience=self.patience,
                threshold=self.threshold,
            ),
            "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
        }
        return [opt], [sch]

    def __call__(self, model):
        """Return the optimizer and scheduler for `model`."""
        if not isinstance(self._lr, float):
            return self._call_multi(model)
        default_lr = self._lr if model.lr is None else model.lr
        opt = torch.optim.AdamW(model.parameters(), lr=default_lr, **self.opt_kwargs)
        sch = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                opt,
                factor=self.factor,
                patience=self.patience,
                threshold=self.threshold,
            ),
            "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
        }
        return [opt], [sch]
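

# Illustrative usage sketch (an assumption for demonstration, not exercised by
# the library): the factories expect the model to expose `lr` (an optional
# per-model override, possibly None) and `use_quantile_regression` attributes,
# which the toy module below fakes.
if __name__ == "__main__":

    class _ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.embedding = torch.nn.Embedding(10, 4)
            self.head = torch.nn.Linear(4, 1)
            self.lr = None  # no per-model learning-rate override
            self.use_quantile_regression = False

    # `lr` given as a dict: parameter-name prefixes map to learning rates,
    # with the "default" entry applied to all remaining parameters.
    factory = AdamWReduceLROnPlateau(lr={"embedding": 1e-3, "default": 5e-4})
    optimizers, schedulers = factory(_ToyModel())
    print(optimizers, schedulers)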