from abc import ABCMeta, abstractmethod

import torch


class PytorchModuleHook(metaclass=ABCMeta):
    """Base class for PyTorch module hook registers.

    An instance of a subclass of PytorchModuleHook can be used to register
    a hook to a PyTorch module using the `register` method:
        hook_register.register(module)

    Subclasses should add/overwrite the following methods:
        - __init__
        - hook
        - hook_type
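
    Example:
        >>> # An illustrative subclass sketch (not part of the original
        >>> # module): a forward hook that prints the output norm. The
        >>> # class name NormLogHook is a hypothetical example.
        >>> class NormLogHook(PytorchModuleHook):
        ...     def hook(self, module, _input, output):
        ...         print(f'output norm: {output.norm().item():.4f}')
        ...
        ...     @property
        ...     def hook_type(self):
        ...         return 'forward'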
    """

    @abstractmethod
    def hook(self, *args, **kwargs):
        """Hook function."""

    @property
    @abstractmethod
    def hook_type(self) -> str:
        """Hook type.

        Subclasses should overwrite this property to return a string value
        in {`forward`, `forward_pre`, `backward`}.
        """

    def register(self, module):
        """Register the hook function to the module.

        Args:
            module (torch.nn.Module): The module to register the hook to.

        Returns:
            handle (torch.utils.hooks.RemovableHandle): A handle that can be
                used to remove the hook by calling handle.remove().
        """
        assert isinstance(module, torch.nn.Module)

        if self.hook_type == 'forward':
            h = module.register_forward_hook(self.hook)
        elif self.hook_type == 'forward_pre':
            h = module.register_forward_pre_hook(self.hook)
        elif self.hook_type == 'backward':
            # Note: register_backward_hook is deprecated in recent versions
            # of PyTorch in favor of register_full_backward_hook.
            h = module.register_backward_hook(self.hook)
        else:
            raise ValueError(f'Invalid hook type {self.hook_type}')

        return h


class WeightNormClipHook(PytorchModuleHook):
    """Apply weight norm clip regularization.

    The module's parameters will be clipped to a given maximum norm before
    each forward pass.

    Args:
        max_norm (float): The maximum norm of the parameter.
        module_param_names (str|list): The parameter name (or a list of
            parameter names) to apply weight norm clipping to.
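
    Example:
        >>> # A minimal illustrative doctest (not part of the original
        >>> # module); the linear layer below is an arbitrary choice.
        >>> import torch
        >>> layer = torch.nn.Linear(3, 3)
        >>> handle = WeightNormClipHook(max_norm=1.0).register(layer)
        >>> _ = layer(torch.randn(1, 3))
        >>> bool(layer.weight.norm() <= 1.0)
        True
        >>> handle.remove()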
    """

    def __init__(self, max_norm=1.0, module_param_names='weight'):
        self.module_param_names = module_param_names if isinstance(
            module_param_names, list) else [module_param_names]
        self.max_norm = max_norm

    @property
    def hook_type(self):
        return 'forward_pre'

    def hook(self, module, _input):
        # Clip each registered parameter in-place before the forward pass.
        for name in self.module_param_names:
            assert name in module._parameters, \
                f'{name} is not a parameter of the module {type(module)}'
            param = module._parameters[name]

            with torch.no_grad():
                m = param.norm().item()
                if m > self.max_norm:
                    # Rescale so the norm does not exceed max_norm; the
                    # small epsilon keeps the result strictly below the
                    # limit, avoiding repeated re-clipping at the boundary.
                    param.mul_(self.max_norm / (m + 1e-6))
|
|
|
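if __name__ == '__main__':
    # Minimal usage sketch (illustrative, not part of the original module):
    # the layer size, scale factor, and input shape are arbitrary
    # assumptions chosen only to demonstrate the hook.
    layer = torch.nn.Linear(4, 4)
    with torch.no_grad():
        layer.weight.mul_(100.0)  # inflate the weight norm past max_norm

    handle = WeightNormClipHook(max_norm=1.0).register(layer)
    layer(torch.randn(2, 4))  # the forward pre-hook clips before this call
    print(f'weight norm after forward: {layer.weight.norm().item():.4f}')

    handle.remove()  # the hook can be detached via the returned handle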