| |
| |
| |
| |
| |
| |
|
|
| from torch.optim.optimizer import Optimizer |
| from torch.optim import Adam, RMSprop, SGD, Adadelta, Adagrad, Adamax, AdamW, ASGD |
| from torch_optimizer import ( |
| AccSGD, |
| AdaBound, |
| AdaMod, |
| DiffGrad, |
| Lamb, |
| NovoGrad, |
| PID, |
| QHAdam, |
| QHM, |
| RAdam, |
| SGDW, |
| Yogi, |
| Ranger, |
| RangerQH, |
| RangerVA, |
| ) |
|
|
|
|
# Public API of this module: the re-exported optimizer classes followed by
# the helper functions defined in this file.
__all__ = [
    "AccSGD",
    "AdaBound",
    "AdaMod",
    "DiffGrad",
    "Lamb",
    "NovoGrad",
    "PID",
    "QHAdam",
    "QHM",
    "RAdam",
    "SGDW",
    "Yogi",
    "Ranger",
    "RangerQH",
    "RangerVA",
    "Adam",
    "RMSprop",
    "SGD",
    "Adadelta",
    "Adagrad",
    "Adamax",
    "AdamW",
    "ASGD",
    "make_optimizer",
    "register_optimizer",
    "get",
]
|
|
|
|
def make_optimizer(params, optim_name="adam", **kwargs):
    """Instantiate an optimizer from its identifier.

    The identifier is resolved with :func:`~.get` and the result is called
    with ``params`` and ``**kwargs``.

    Args:
        params (iterable): Output of `nn.Module.parameters()`.
        optim_name (str): Identifier understood by :func:`~.get`, typically
            the optimizer name. Defaults to ``"adam"``.
        **kwargs (dict): keyword arguments for the optimizer.

    Returns:
        torch.optim.Optimizer

    Examples
        >>> from torch import nn
        >>> model = nn.Sequential(nn.Linear(10, 10))
        >>> optimizer = make_optimizer(model.parameters(), optim_name='sgd',
        >>>                            lr=1e-3)
    """
    optimizer_cls = get(optim_name)
    return optimizer_cls(params, **kwargs)
|
|
|
|
def register_optimizer(custom_opt):
    """Register a custom optimizer, gettable with `optimizers.get`.

    The optimizer is stored in this module's globals under its ``__name__``.

    Args:
        custom_opt: Custom optimizer to register. Must expose ``__name__``.

    Raises:
        ValueError: If a module-level name already matches
            ``custom_opt.__name__``, ignoring case.
    """
    name = custom_opt.__name__
    # Compare case-insensitively so that two registered names can never
    # differ only by case (which would make a lower-cased lookup ambiguous).
    taken = {existing.lower() for existing in globals()}
    if name.lower() in taken:
        raise ValueError(
            f"Optimizer {name} already exists. Choose another name."
        )
    globals()[name] = custom_opt
|
|
|
|
def get(identifier):
    """Returns an optimizer from a string. Returns its input unchanged if it
    is already a :class:`torch.optim.Optimizer` (subclass or instance).

    Args:
        identifier (str or torch.optim.Optimizer): the optimizer identifier.
            Strings are matched case-insensitively against the names
            available in this module.

    Returns:
        The object found under that name (normally a
        :class:`torch.optim.Optimizer` subclass), or ``identifier`` itself
        when it is already an optimizer class or instance.

    Raises:
        ValueError: If ``identifier`` is neither a string nor an optimizer,
            or if no module-level name matches the string.
    """
    if isinstance(identifier, Optimizer):
        return identifier
    # Optimizer classes are passed through as well, so the result can be
    # called to build an instance.
    if isinstance(identifier, type) and issubclass(identifier, Optimizer):
        return identifier
    if isinstance(identifier, str):
        by_lower_name = {name.lower(): obj for name, obj in globals().items()}
        found = by_lower_name.get(identifier.lower())
        if found is not None:
            return found
    raise ValueError(f"Could not interpret optimizer : {identifier!s}")
|
|