|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torch.optim.optimizer import Optimizer |
|
|
from torch.optim import Adam, RMSprop, SGD, Adadelta, Adagrad, Adamax, AdamW, ASGD |
|
|
from torch_optimizer import ( |
|
|
AccSGD, |
|
|
AdaBound, |
|
|
AdaMod, |
|
|
DiffGrad, |
|
|
Lamb, |
|
|
NovoGrad, |
|
|
PID, |
|
|
QHAdam, |
|
|
QHM, |
|
|
RAdam, |
|
|
SGDW, |
|
|
Yogi, |
|
|
Ranger, |
|
|
RangerQH, |
|
|
RangerVA, |
|
|
) |
|
|
|
|
|
|
|
|
# Public API of this module: the optimizer classes re-exported from
# `torch_optimizer` and `torch.optim`, plus the helper functions defined below.
__all__ = [
    # Re-exported from `torch_optimizer`.
    "AccSGD",
    "AdaBound",
    "AdaMod",
    "DiffGrad",
    "Lamb",
    "NovoGrad",
    "PID",
    "QHAdam",
    "QHM",
    "RAdam",
    "SGDW",
    "Yogi",
    "Ranger",
    "RangerQH",
    "RangerVA",
    # Re-exported from `torch.optim`.
    "Adam",
    "RMSprop",
    "SGD",
    "Adadelta",
    "Adagrad",
    "Adamax",
    "AdamW",
    "ASGD",
    # Helpers defined in this module.
    "make_optimizer",
    "register_optimizer",
    "get",
]
|
|
|
|
|
|
|
|
def make_optimizer(params, optim_name="adam", **kwargs):
    """Build an optimizer for ``params`` from an identifier.

    The identifier is resolved with :func:`~.get` and the resulting object is
    called as ``resolved(params, **kwargs)``.

    Args:
        params (iterable): Output of `nn.Module.parameters()`.
        optim_name (str): Identifier understood by :func:`~.get`, e.g. the
            name of the optimizer. Defaults to ``"adam"``.
        **kwargs (dict): keyword arguments for the optimizer.

    Returns:
        torch.optim.Optimizer: The result of calling the resolved optimizer
        with ``params`` and ``kwargs``.

    Raises:
        ValueError: Propagated from :func:`~.get` when ``optim_name`` cannot
            be interpreted.

    Examples
        >>> from torch import nn
        >>> model = nn.Sequential(nn.Linear(10, 10))
        >>> optimizer = make_optimizer(model.parameters(), optim_name='sgd',
        >>>                            lr=1e-3)
    """
    # Resolve first, then instantiate: a bad identifier fails inside `get`,
    # before any optimizer constructor is invoked.
    optimizer_factory = get(optim_name)
    return optimizer_factory(params, **kwargs)
|
|
|
|
|
|
|
|
def register_optimizer(custom_opt):
    """Register a custom optimizer, gettable with `optimizers.get`.

    The object is stored in this module's globals under ``custom_opt.__name__``.

    Args:
        custom_opt: Custom optimizer to register. Must expose a ``__name__``
            attribute (a class or a function, for example).

    Raises:
        ValueError: If a module-level name equal to ``custom_opt.__name__``,
            ignoring case, already exists.
    """
    name = custom_opt.__name__
    # `get` lower-cases every module-level name before looking an identifier
    # up, so two names differing only by case would shadow each other there.
    # Detect collisions with the same case-insensitive rule.
    taken = {existing.lower() for existing in globals()}
    if name.lower() in taken:
        raise ValueError(f"Optimizer {name} already exists. Choose another name.")
    globals()[name] = custom_opt
|
|
|
|
|
|
|
|
def get(identifier):
    """Return the optimizer designated by ``identifier``.

    A :class:`torch.optim.Optimizer` instance or any callable (an optimizer
    class, for example) is returned unchanged. A string is matched
    case-insensitively against the names defined at module level, which
    includes anything added through :func:`register_optimizer`.

    Args:
        identifier (str, Callable or :class:`torch.optim.Optimizer`): the
            optimizer identifier.

    Returns:
        The input itself when it is an optimizer instance or a callable,
        otherwise the module-level object whose name matches the string.

    Raises:
        ValueError: If ``identifier`` is a string matching no module-level
            name, or is neither a string, a callable nor an optimizer
            instance.
    """
    # Already usable as-is: hand it back untouched.
    if isinstance(identifier, Optimizer) or callable(identifier):
        return identifier
    if isinstance(identifier, str):
        # Case-insensitive view of the module namespace.
        lowered_namespace = {name.lower(): obj for name, obj in globals().items()}
        found = lowered_namespace.get(identifier.lower())
        if found is None:
            raise ValueError(f"Could not interpret optimizer : {str(identifier)}")
        return found
    raise ValueError(f"Could not interpret optimizer : {str(identifier)}")
|
|
|