| |
| from abc import ABCMeta, abstractmethod, abstractproperty |
|
|
| import torch |
|
|
|
|
class PytorchModuleHook(metaclass=ABCMeta):
    """Base class for PyTorch module hook registers.

    An instance of a subclass of PytorchModuleHook can be used to
    register a hook to a PyTorch module using the `register` method like:
        hook_register.register(module)

    Subclasses should add/overwrite the following methods:
        - __init__
        - hook
        - hook_type
    """

    @abstractmethod
    def hook(self, *args, **kwargs):
        """Hook function.

        This is the callable handed to the module's ``register_*_hook``
        method selected by `hook_type`.
        """

    # ``abstractproperty`` is deprecated; stacking ``property`` on top of
    # ``abstractmethod`` is the equivalent supported spelling.
    @property
    @abstractmethod
    def hook_type(self) -> str:
        """Hook type.

        Subclasses should overwrite this property to return a string value
        in {`forward`, `forward_pre`, `backward`}.
        """

    def register(self, module):
        """Register the hook function to the module.

        Args:
            module (pytorch module): the module to register the hook.

        Returns:
            handle (torch.utils.hooks.RemovableHandle): a handle to remove
                the hook by calling handle.remove()

        Raises:
            AssertionError: if `module` is not a ``torch.nn.Module``.
            ValueError: if `hook_type` is not one of `forward`,
                `forward_pre` or `backward`.
        """
        assert isinstance(module, torch.nn.Module), \
            f'module must be a torch.nn.Module, got {type(module)}'

        # Read the property once so the dispatch and the error message are
        # guaranteed to talk about the same value.
        hook_type = self.hook_type
        if hook_type == 'forward':
            return module.register_forward_hook(self.hook)
        if hook_type == 'forward_pre':
            return module.register_forward_pre_hook(self.hook)
        if hook_type == 'backward':
            return module.register_backward_hook(self.hook)

        # Report the offending hook *type*, not the bound hook method.
        raise ValueError(f'Invalid hook type {hook_type}')
|
|
|
|
class WeightNormClipHook(PytorchModuleHook):
    """Apply weight norm clip regularization.

    The module's parameters are clipped to a given maximum norm before each
    forward pass.

    Args:
        max_norm (float): The maximum norm of the parameter.
        module_param_names (str|list|tuple): The parameter name (or a
            list/tuple of names) to apply weight norm clip.
    """

    def __init__(self, max_norm=1.0, module_param_names='weight'):
        # Accept tuples as well as lists; anything else is treated as a
        # single parameter name.
        if isinstance(module_param_names, (list, tuple)):
            self.module_param_names = list(module_param_names)
        else:
            self.module_param_names = [module_param_names]
        self.max_norm = max_norm

    @property
    def hook_type(self):
        """Return `forward_pre` so the clip is registered as a pre-forward hook."""
        return 'forward_pre'

    def hook(self, module, _input):
        """Rescale each configured parameter in place if its norm is too large.

        Args:
            module (torch.nn.Module): the module whose parameters are
                clipped.
            _input: the forward inputs supplied by PyTorch; unused.

        Raises:
            AssertionError: if a configured name is not a parameter of
                `module`.
        """
        for name in self.module_param_names:
            assert name in module._parameters, f'{name} is not a parameter' \
                f' of the module {type(module)}'
            param = module._parameters[name]

            # A registered-but-absent parameter (stored as None) has
            # nothing to clip.
            if param is None:
                continue

            # The rescale must not be recorded by autograd.
            with torch.no_grad():
                norm = param.norm().item()
                if norm > self.max_norm:
                    # Small epsilon added to the denominator, as in the
                    # original implementation.
                    param.mul_(self.max_norm / (norm + 1e-6))
|
|