from collections import defaultdict
from enum import Enum
from typing import cast, Iterable, Tuple, Union

import torch
from captum._utils.common import _format_tensor_into_tuples, _register_backward_hook
from torch import Tensor
from torch.nn import Module


def _reset_sample_grads(module: Module) -> None:
    module.weight.sample_grad = 0  # type: ignore
    if module.bias is not None:
        module.bias.sample_grad = 0  # type: ignore


def linear_param_grads(
    module: Module, activation: Tensor, gradient_out: Tensor, reset: bool = False
) -> None:
    r"""
    Computes parameter gradients per sample for nn.Linear module, given module
    input activations and output gradients.

    Gradients are accumulated in the sample_grad attribute of each parameter
    (weight and bias). If reset = True, any current sample_grad values are reset;
    otherwise computed gradients are accumulated and added to the existing
    stored gradients.

    Inputs with more than 2 dimensions are only supported with torch 1.8 or later.
    """
    if reset:
        _reset_sample_grads(module)

    module.weight.sample_grad += torch.einsum(  # type: ignore
        "n...i,n...j->nij", gradient_out, activation
    )
    if module.bias is not None:
        module.bias.sample_grad += torch.einsum(  # type: ignore
            "n...i->ni", gradient_out
        )
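
# Illustrative shape note for linear_param_grads (added here as a sketch, not
# part of the original module): for a plain 2D input, activation has shape
# (n, in_features) and gradient_out has shape (n, out_features). The einsum
# "n...i,n...j->nij" then reduces to a per-sample outer product,
#     sample_grad[k] = gradient_out[k].unsqueeze(1) @ activation[k].unsqueeze(0)
# i.e. one (out_features, in_features) matrix per sample k, matching
# module.weight.shape.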


def conv2d_param_grads(
    module: Module, activation: Tensor, gradient_out: Tensor, reset: bool = False
) -> None:
    r"""
    Computes parameter gradients per sample for nn.Conv2d module, given module
    input activations and output gradients.

    nn.Conv2d modules with padding set to a string option ('same' or 'valid') are
    currently unsupported.

    Gradients are accumulated in the sample_grad attribute of each parameter
    (weight and bias). If reset = True, any current sample_grad values are reset;
    otherwise computed gradients are accumulated and added to the existing
    stored gradients.
    """
    if reset:
        _reset_sample_grads(module)

    batch_size = cast(int, activation.shape[0])
    unfolded_act = torch.nn.functional.unfold(
        activation,
        cast(Union[int, Tuple[int, ...]], module.kernel_size),
        dilation=cast(Union[int, Tuple[int, ...]], module.dilation),
        padding=cast(Union[int, Tuple[int, ...]], module.padding),
        stride=cast(Union[int, Tuple[int, ...]], module.stride),
    )
    reshaped_grad = gradient_out.reshape(batch_size, -1, unfolded_act.shape[-1])
    grad1 = torch.einsum("ijk,ilk->ijl", reshaped_grad, unfolded_act)
    shape = [batch_size] + list(cast(Iterable[int], module.weight.shape))

    module.weight.sample_grad += grad1.reshape(shape)  # type: ignore
    if module.bias is not None:
        module.bias.sample_grad += torch.sum(reshaped_grad, dim=2)  # type: ignore
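
# Illustrative shape note for conv2d_param_grads (added here as a sketch, not
# part of the original module, and assuming groups == 1): for an activation of
# shape (n, c_in, h, w), torch.nn.functional.unfold produces
# (n, c_in * kh * kw, L), where L is the number of sliding-window positions.
# gradient_out is reshaped to (n, c_out, L), so einsum("ijk,ilk->ijl") sums
# over the L positions and yields (n, c_out, c_in * kh * kw), which is then
# reshaped to the per-sample weight-gradient shape (n, *module.weight.shape).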


# Maps each supported module type to the function that computes its per-sample
# parameter gradients from captured activations and output gradients.
SUPPORTED_MODULES = {
    torch.nn.Conv2d: conv2d_param_grads,
    torch.nn.Linear: linear_param_grads,
}


class LossMode(Enum):
    SUM = 0
    MEAN = 1


class SampleGradientWrapper:
    r"""
    Wrapper which allows computing sample-wise gradients in a single backward pass.

    This is accomplished by adding hooks to capture activations and output
    gradients for supported modules, and using these activations and gradients
    to compute the parameter gradients per-sample.

    Currently, only nn.Linear and nn.Conv2d modules are supported.

    Similar reference implementations of sample-based gradients include:
    - https://github.com/cybertronai/autograd-hacks
    - https://github.com/pytorch/opacus/tree/main/opacus/grad_sample
    """

    def __init__(self, model):
        self.model = model
        self.hooks_added = False
        # Each dict maps a supported module to the list of tensors captured for
        # it, one entry per use of the module in the forward / backward pass.
        self.activation_dict = defaultdict(list)
        self.gradient_dict = defaultdict(list)
        self.forward_hooks = []
        self.backward_hooks = []

    def add_hooks(self):
        self.hooks_added = True
        self.model.apply(self._register_module_hooks)

    def _register_module_hooks(self, module: torch.nn.Module):
        if isinstance(module, tuple(SUPPORTED_MODULES.keys())):
            self.forward_hooks.append(
                module.register_forward_hook(self._forward_hook_fn)
            )
            self.backward_hooks.append(
                _register_backward_hook(module, self._backward_hook_fn, None)
            )

    def _forward_hook_fn(
        self,
        module: Module,
        module_input: Union[Tensor, Tuple[Tensor, ...]],
        module_output: Union[Tensor, Tuple[Tensor, ...]],
    ):
        # Store a detached copy of the module's (first) input tensor; one entry
        # is appended per forward call, so repeated uses of a module are kept.
        inp_tuple = _format_tensor_into_tuples(module_input)
        self.activation_dict[module].append(inp_tuple[0].clone().detach())

    def _backward_hook_fn(
        self,
        module: Module,
        grad_input: Union[Tensor, Tuple[Tensor, ...]],
        grad_output: Union[Tensor, Tuple[Tensor, ...]],
    ):
        # Store a detached copy of the gradient w.r.t. the module's (first)
        # output, again one entry per call.
        grad_output_tuple = _format_tensor_into_tuples(grad_output)
        self.gradient_dict[module].append(grad_output_tuple[0].clone().detach())

    def remove_hooks(self):
        self.hooks_added = False
        for hook in self.forward_hooks:
            hook.remove()
        for hook in self.backward_hooks:
            hook.remove()
        self.forward_hooks = []
        self.backward_hooks = []

    def _reset(self):
        self.activation_dict = defaultdict(list)
        self.gradient_dict = defaultdict(list)

    def compute_param_sample_gradients(self, loss_blob, loss_mode="mean"):
        assert (
            loss_mode.upper() in LossMode.__members__
        ), f"Provided loss mode {loss_mode} is not valid"
        mode = LossMode[loss_mode.upper()]

        self.model.zero_grad()
        loss_blob.backward(gradient=torch.ones_like(loss_blob))

        for module in self.gradient_dict:
            sample_grad_fn = SUPPORTED_MODULES[type(module)]
            activations = self.activation_dict[module]
            gradients = self.gradient_dict[module]
            assert len(activations) == len(gradients), (
                "Number of saved activations does not match number of saved"
                " gradients. This may occur if multiple forward passes are run"
                " without calling reset or computing param gradients."
            )
            # When a module is used multiple times in one forward pass, its
            # output gradients arrive in reverse order during backprop, so the
            # gradient list is reversed here to align it with the activations.
            for i, (act, grad) in enumerate(
                zip(activations, list(reversed(gradients)))
            ):
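                # Note (added): for a mean-reduced loss, autograd scales each
                # sample's contribution by 1 / batch_size; multiplying the
                # output gradient by act.shape[0] undoes that scaling so
                # sample_grad holds unscaled per-sample gradients. This assumes
                # the mean is taken over the batch dimension.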
                mult = 1 if mode is LossMode.SUM else act.shape[0]
                sample_grad_fn(module, act, grad * mult, reset=(i == 0))
        self._reset()
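

# Minimal usage sketch (illustrative only; the toy model, data, and loss below
# are placeholders and not part of this module's API):
if __name__ == "__main__":
    _model = torch.nn.Sequential(
        torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 1)
    )
    _wrapper = SampleGradientWrapper(_model)
    _wrapper.add_hooks()

    _inputs = torch.randn(8, 10)
    _loss = (_model(_inputs) ** 2).mean()
    _wrapper.compute_param_sample_gradients(_loss, loss_mode="mean")

    # Per-sample weight gradients of the first Linear layer: shape (8, 5, 10).
    print(_model[0].weight.sample_grad.shape)

    _wrapper.remove_hooks()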