|
|
from collections import defaultdict, abc as container_abcs |
|
|
import torch |
|
|
from copy import deepcopy |
|
|
from itertools import chain |
|
|
import warnings |
|
|
import functools |
|
|
|
|
|
__all__ = ['Optimizer'] |
|
|
|
|
|
class _RequiredParameter(object): |
|
|
"""Singleton class representing a required parameter for an Optimizer.""" |
|
|
def __repr__(self): |
|
|
return "<required parameter>" |
|
|
|
|
|
required = _RequiredParameter() |
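
# Illustrative note (not part of the public API): concrete optimizers mark mandatory
# hyperparameters by using ``required`` as the default value in their ``defaults`` dict,
# e.g. an SGD-like subclass might pass ``defaults = dict(lr=required, momentum=0)``;
# ``add_param_group`` below then raises a ``ValueError`` if a group omits such a value.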
|
|
|
|
|
|
|
|
def _use_grad_for_differentiable(func): |
|
|
def _use_grad(self, *args, **kwargs): |
|
|
prev_grad = torch.is_grad_enabled() |
|
|
try: |
|
|
torch.set_grad_enabled(self.defaults['differentiable']) |
|
|
ret = func(self, *args, **kwargs) |
|
|
finally: |
|
|
torch.set_grad_enabled(prev_grad) |
|
|
return ret |
|
|
return _use_grad |
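
# A minimal usage sketch (illustrative only): concrete optimizers typically decorate
# their ``step`` implementation with the helper above so that grad mode follows the
# ``differentiable`` default, e.g.
#
#     class MySGD(Optimizer):          # hypothetical subclass, for illustration
#         @_use_grad_for_differentiable
#         def step(self, closure=None):
#             ...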
|
|
|
|
|
|
|
|
class Optimizer(object): |
|
|
r"""Base class for all optimizers. |
|
|
|
|
|
.. warning:: |
|
|
Parameters need to be specified as collections that have a deterministic |
|
|
ordering that is consistent between runs. Examples of objects that don't |
|
|
satisfy those properties are sets and iterators over values of dictionaries. |
|
|
|
|
|
Args: |
|
|
params (iterable): an iterable of :class:`torch.Tensor` s or |
|
|
:class:`dict` s. Specifies what Tensors should be optimized. |
|
|
defaults: (dict): a dict containing default values of optimization |
|
|
options (used when a parameter group doesn't specify them). |
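
    A minimal usage sketch with a concrete subclass (illustrative only; ``model``,
    ``loss_fn``, ``input`` and ``target`` are placeholders)::

        >>> # illustrative only; the names above are placeholders
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()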
|
|
""" |
|
|
|
|
|
def __init__(self, params, defaults): |
|
|
torch._C._log_api_usage_once("python.optimizer") |
|
|
self.defaults = defaults |
|
|
|
|
|
self._hook_for_profile() |
|
|
|
|
|
if isinstance(params, torch.Tensor): |
|
|
raise TypeError("params argument given to the optimizer should be " |
|
|
"an iterable of Tensors or dicts, but got " + |
|
|
torch.typename(params)) |
|
|
|
|
|
self.state = defaultdict(dict) |
|
|
self.param_groups = [] |
|
|
|
|
|
param_groups = list(params) |
|
|
if len(param_groups) == 0: |
|
|
raise ValueError("optimizer got an empty parameter list") |
|
|
if not isinstance(param_groups[0], dict): |
|
|
param_groups = [{'params': param_groups}] |
|
|
|
|
|
for param_group in param_groups: |
|
|
self.add_param_group(param_group) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # Warn-once flag consulted by _cuda_graph_capture_health_check(); while it is
        # True, the "capturable but running uncaptured" warning below is suppressed.
        self._warned_capturable_if_run_uncaptured = True
|
|
|
|
|
|
|
|
def __getstate__(self): |
|
|
return { |
|
|
'defaults': self.defaults, |
|
|
'state': self.state, |
|
|
'param_groups': self.param_groups, |
|
|
} |
|
|
|
|
|
def __setstate__(self, state): |
|
|
self.__dict__.update(state) |
|
|
self._hook_for_profile() |
|
|
self.defaults.setdefault('differentiable', False) |
|
|
|
|
|
def __repr__(self): |
|
|
format_string = self.__class__.__name__ + ' (' |
|
|
for i, group in enumerate(self.param_groups): |
|
|
format_string += '\n' |
|
|
format_string += 'Parameter Group {0}\n'.format(i) |
|
|
for key in sorted(group.keys()): |
|
|
if key != 'params': |
|
|
format_string += ' {0}: {1}\n'.format(key, group[key]) |
|
|
format_string += ')' |
|
|
return format_string |
|
|
|
|
|
|
|
|
def _cuda_graph_capture_health_check(self): |
|
|
if torch.has_cuda and torch.cuda.is_available(): |
|
|
capturing = torch.cuda.is_current_stream_capturing() |
|
|
|
|
|
if capturing and not self.defaults['capturable']: |
|
|
raise RuntimeError("Attempting CUDA graph capture of step() for an instance of " + |
|
|
self.__class__.__name__ + |
|
|
" but this instance was constructed with capturable=False.") |
|
|
|
|
|
if ( |
|
|
(not getattr(self, "_warned_capturable_if_run_uncaptured", False)) |
|
|
and self.defaults["capturable"] |
|
|
and (not capturing) |
|
|
): |
|
|
print("Warning: This instance was constructed with capturable=True, but step() " + |
|
|
"is running without CUDA graph capture. If you never intend to graph-capture this " + |
|
|
"instance, capturable=True can impair performance, and you should set capturable=False.") |
|
|
self._warned_capturable_if_run_uncaptured = True |
|
|
|
|
|
def _optimizer_step_code(self): |
|
|
"""Entry point for `torch.profile.profiler`. |
|
|
|
|
|
When python tracing is enabled the profiler will hook into this |
|
|
function at the CPython level to inspect the optimizer's parameters and |
|
|
        param groups. It is called after `step()` since many optimizers
|
|
lazily initialize state. |
|
|
|
|
|
This is a workaround due to lack of a proper step hook on the optimizer, |
|
|
        and will be removed once such a hook exists.
|
|
""" |
|
|
pass |
|
|
|
|
|
def _hook_for_profile(self): |
|
|
self._zero_grad_profile_name = "Optimizer.zero_grad#{}.zero_grad".format(self.__class__.__name__) |
|
|
|
|
|
def profile_hook_step(func): |
|
|
|
|
|
@functools.wraps(func) |
|
|
def wrapper(*args, **kwargs): |
|
|
obj, *_ = args |
|
|
profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__) |
|
|
with torch.autograd.profiler.record_function(profile_name): |
|
|
out = func(*args, **kwargs) |
|
|
obj._optimizer_step_code() |
|
|
return out |
|
|
|
|
|
return wrapper |
|
|
|
|
|
hooked = getattr(self.__class__.step, "hooked", None) |
|
|
if not hooked: |
|
|
self.__class__.step = profile_hook_step(self.__class__.step) |
|
|
self.__class__.step.hooked = True |
|
|
|
|
|
def state_dict(self): |
|
|
r"""Returns the state of the optimizer as a :class:`dict`. |
|
|
|
|
|
It contains two entries: |
|
|
|
|
|
* state - a dict holding current optimization state. Its content |
|
|
differs between optimizer classes. |
|
|
* param_groups - a list containing all parameter groups where each |
|
|
            parameter group is a dict.
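
        A sketch of the returned structure (illustrative only; the exact contents of
        ``state`` differ between optimizer classes)::

            {
                'state': {0: {...}, 1: {...}},                    # keyed by integer parameter indices
                'param_groups': [{'lr': 0.01, 'params': [0, 1]}],
            }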
|
|
""" |
|
|
|
|
|
param_mappings = {} |
|
|
start_index = 0 |
|
|
|
|
|
def pack_group(group): |
|
|
nonlocal start_index |
|
|
packed = {k: v for k, v in group.items() if k != 'params'} |
|
|
param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index) |
|
|
if id(p) not in param_mappings}) |
|
|
packed['params'] = [param_mappings[id(p)] for p in group['params']] |
|
|
start_index += len(packed['params']) |
|
|
return packed |
|
|
param_groups = [pack_group(g) for g in self.param_groups] |
|
|
|
|
|
packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v |
|
|
for k, v in self.state.items()} |
|
|
return { |
|
|
'state': packed_state, |
|
|
'param_groups': param_groups, |
|
|
} |
|
|
|
|
|
def load_state_dict(self, state_dict): |
|
|
r"""Loads the optimizer state. |
|
|
|
|
|
Args: |
|
|
state_dict (dict): optimizer state. Should be an object returned |
|
|
from a call to :meth:`state_dict`. |
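
        A typical checkpointing round trip (illustrative only; ``PATH`` is a placeholder)::

            >>> torch.save(optimizer.state_dict(), PATH)    # PATH is a placeholder
            >>> optimizer.load_state_dict(torch.load(PATH))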
|
|
""" |
|
|
|
|
|
state_dict = deepcopy(state_dict) |
|
|
|
|
|
groups = self.param_groups |
|
|
saved_groups = state_dict['param_groups'] |
|
|
|
|
|
if len(groups) != len(saved_groups): |
|
|
raise ValueError("loaded state dict has a different number of " |
|
|
"parameter groups") |
|
|
param_lens = (len(g['params']) for g in groups) |
|
|
saved_lens = (len(g['params']) for g in saved_groups) |
|
|
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): |
|
|
raise ValueError("loaded state dict contains a parameter group " |
|
|
"that doesn't match the size of optimizer's group") |
|
|
|
|
|
|
|
|
        # Map ids of the saved params to the current params, matching positionally
        # across the (already size-checked) parameter groups.
        id_map = {old_id: p for old_id, p in
|
|
zip(chain.from_iterable((g['params'] for g in saved_groups)), |
|
|
chain.from_iterable((g['params'] for g in groups)))} |
|
|
|
|
|
def cast(param, value, key=None): |
|
|
r"""Make a deep copy of value, casting all tensors to device of param.""" |
|
|
if isinstance(value, torch.Tensor): |
|
|
|
|
|
|
|
|
|
|
|
                # Floating-point state is cast to the params' dtype; the 'step' counter
                # is deliberately left in its original (typically integral) dtype.
                if (key != "step"):
|
|
if param.is_floating_point(): |
|
|
value = value.to(param.dtype) |
|
|
value = value.to(param.device) |
|
|
return value |
|
|
elif isinstance(value, dict): |
|
|
return {k: cast(param, v, key=k) for k, v in value.items()} |
|
|
elif isinstance(value, container_abcs.Iterable): |
|
|
return type(value)(cast(param, v) for v in value) |
|
|
else: |
|
|
return value |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # Copy state assigned to params (and cast tensors to the appropriate device/dtype).
        # State that is not assigned to params is copied as-is (for backward compatibility).
        state = defaultdict(dict)
|
|
for k, v in state_dict['state'].items(): |
|
|
if k in id_map: |
|
|
param = id_map[k] |
|
|
state[param] = cast(param, v) |
|
|
else: |
|
|
state[k] = v |
|
|
|
|
|
|
|
|
        # Update parameter groups, keeping the current params but taking all other
        # options from the saved groups.
        def update_group(group, new_group):
|
|
new_group['params'] = group['params'] |
|
|
return new_group |
|
|
param_groups = [ |
|
|
update_group(g, ng) for g, ng in zip(groups, saved_groups)] |
|
|
self.__setstate__({'state': state, 'param_groups': param_groups}) |
|
|
|
|
|
def zero_grad(self, set_to_none: bool = False): |
|
|
r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero. |
|
|
|
|
|
Args: |
|
|
set_to_none (bool): instead of setting to zero, set the grads to None. |
|
|
This will in general have lower memory footprint, and can modestly improve performance. |
|
|
However, it changes certain behaviors. For example: |
|
|
1. When the user tries to access a gradient and perform manual ops on it, |
|
|
a None attribute or a Tensor full of 0s will behave differently. |
|
|
2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s |
|
|
are guaranteed to be None for params that did not receive a gradient. |
|
|
3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None |
|
|
(in one case it does the step with a gradient of 0 and in the other it skips |
|
|
the step altogether). |
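
        Example (illustrative only; assumes gradients were populated by a prior
        backward pass)::

            >>> optimizer.zero_grad(set_to_none=True)   # illustrative
            >>> optimizer.param_groups[0]['params'][0].grad is None
            True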
|
|
""" |
|
|
foreach = self.defaults.get('foreach', False) |
|
|
|
|
|
if not hasattr(self, "_zero_grad_profile_name"): |
|
|
self._hook_for_profile() |
|
|
if foreach: |
|
|
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) |
|
|
with torch.autograd.profiler.record_function(self._zero_grad_profile_name): |
|
|
for group in self.param_groups: |
|
|
for p in group['params']: |
|
|
if p.grad is not None: |
|
|
if set_to_none: |
|
|
p.grad = None |
|
|
else: |
|
|
if p.grad.grad_fn is not None: |
|
|
p.grad.detach_() |
|
|
else: |
|
|
p.grad.requires_grad_(False) |
|
|
if (not foreach or p.grad.is_sparse): |
|
|
p.grad.zero_() |
|
|
else: |
|
|
per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad) |
|
|
if foreach: |
|
|
for _, per_dtype_grads in per_device_and_dtype_grads.items(): |
|
|
for grads in per_dtype_grads.values(): |
|
|
torch._foreach_zero_(grads) |
|
|
|
|
|
def step(self, closure): |
|
|
r"""Performs a single optimization step (parameter update). |
|
|
|
|
|
Args: |
|
|
closure (Callable): A closure that reevaluates the model and |
|
|
returns the loss. Optional for most optimizers. |
|
|
|
|
|
.. note:: |
|
|
Unless otherwise specified, this function should not modify the |
|
|
``.grad`` field of the parameters. |
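
        Example (illustrative only; ``model``, ``loss_fn``, ``input`` and ``target``
        are placeholders, and the closure form is only needed by optimizers that must
        re-evaluate the loss)::

            >>> def closure():   # illustrative; names are placeholders
            ...     optimizer.zero_grad()
            ...     loss = loss_fn(model(input), target)
            ...     loss.backward()
            ...     return loss
            >>> optimizer.step(closure)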
|
|
""" |
|
|
raise NotImplementedError |
|
|
|
|
|
def add_param_group(self, param_group): |
|
|
r"""Add a param group to the :class:`Optimizer` s `param_groups`. |
|
|
|
|
|
        This can be useful when fine-tuning a pre-trained network, as frozen layers can be made
|
|
trainable and added to the :class:`Optimizer` as training progresses. |
|
|
|
|
|
Args: |
|
|
param_group (dict): Specifies what Tensors should be optimized along with group |
|
|
specific optimization options. |
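
        Example (illustrative only; ``model.fc`` and the learning rate are placeholders)::

            >>> # illustrative; ``model.fc`` is a placeholder
            >>> optimizer.add_param_group({'params': model.fc.parameters(), 'lr': 1e-3})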
|
|
""" |
|
|
assert isinstance(param_group, dict), "param group must be a dict" |
|
|
|
|
|
params = param_group['params'] |
|
|
if isinstance(params, torch.Tensor): |
|
|
param_group['params'] = [params] |
|
|
elif isinstance(params, set): |
|
|
raise TypeError('optimizer parameters need to be organized in ordered collections, but ' |
|
|
'the ordering of tensors in sets will change between runs. Please use a list instead.') |
|
|
else: |
|
|
param_group['params'] = list(params) |
|
|
|
|
|
for param in param_group['params']: |
|
|
if not isinstance(param, torch.Tensor): |
|
|
raise TypeError("optimizer can only optimize Tensors, " |
|
|
"but one of the params is " + torch.typename(param)) |
|
|
if not self.defaults.get('differentiable', None) and not (param.is_leaf or param.retains_grad): |
|
|
raise ValueError("can't optimize a non-leaf Tensor") |
|
|
|
|
|
for name, default in self.defaults.items(): |
|
|
if default is required and name not in param_group: |
|
|
raise ValueError("parameter group didn't specify a value of required optimization parameter " + |
|
|
name) |
|
|
else: |
|
|
param_group.setdefault(name, default) |
|
|
|
|
|
params = param_group['params'] |
|
|
if len(params) != len(set(params)): |
|
|
warnings.warn("optimizer contains a parameter group with duplicate parameters; " |
|
|
"in future, this will cause an error; " |
|
|
"see github.com/pytorch/pytorch/issues/40967 for more information", stacklevel=3) |
|
|
|
|
|
param_set = set() |
|
|
for group in self.param_groups: |
|
|
param_set.update(set(group['params'])) |
|
|
|
|
|
if not param_set.isdisjoint(set(param_group['params'])): |
|
|
raise ValueError("some parameters appear in more than one parameter group") |
|
|
|
|
|
self.param_groups.append(param_group) |
|
|
|