Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from abc import ABCMeta, abstractmethod, abstractproperty | |
| import torch | |
class PytorchModuleHook(metaclass=ABCMeta):
    """Base class for PyTorch module hook registers.

    An instance of a subclass of PytorchModuleHook can be used to
    register hook to a pytorch module using the `register` method like:
        hook_register.register(module)

    Subclasses should add/overwrite the following methods:
        - __init__
        - hook
        - hook_type
    """

    @abstractmethod
    def hook(self, *args, **kwargs):
        """Hook function.

        This is the callable handed to the module's hook registration
        method in `register`; its arguments are whatever PyTorch passes
        to a hook of the kind named by `hook_type`.
        """

    @property
    @abstractmethod
    def hook_type(self) -> str:
        """Hook type.

        Subclasses should overwrite this to return a string value in
        {`forward`, `forward_pre`, `backward`}.
        """

    def register(self, module):
        """Register the hook function to the module.

        Args:
            module (pytorch module): the module to register the hook.

        Returns:
            handle (torch.utils.hooks.RemovableHandle): a handle to remove
                the hook by calling handle.remove()

        Raises:
            AssertionError: if `module` is not a `torch.nn.Module`.
            ValueError: if `hook_type` is not one of `forward`,
                `forward_pre` or `backward`.
        """
        assert isinstance(module, torch.nn.Module)

        # `hook_type` is meant to be a property, but tolerate subclasses that
        # define it as a plain method: comparing a bound method against a
        # string would otherwise never match and always end in ValueError.
        hook_type = self.hook_type
        if callable(hook_type):
            hook_type = hook_type()

        if hook_type == 'forward':
            h = module.register_forward_hook(self.hook)
        elif hook_type == 'forward_pre':
            h = module.register_forward_pre_hook(self.hook)
        elif hook_type == 'backward':
            # NOTE(review): recent PyTorch releases document
            # `register_full_backward_hook` as the replacement for this
            # method; kept as-is to preserve behavior -- confirm against the
            # supported torch versions.
            h = module.register_backward_hook(self.hook)
        else:
            # Report the offending hook type (not the hook callable).
            raise ValueError(f'Invalid hook type {hook_type}')
        return h
class WeightNormClipHook(PytorchModuleHook):
    """Apply weight norm clip regularization.

    The module's parameter will be clip to a given maximum norm before each
    forward pass.

    Args:
        max_norm (float): The maximum norm of the parameter.
        module_param_names (str|list|tuple): The parameter name (or a
            list/tuple of names) to apply weight norm clip.
    """

    def __init__(self, max_norm=1.0, module_param_names='weight'):
        # Normalize to a list of names. A tuple is treated as a sequence of
        # names rather than being wrapped as one (invalid) name.
        if isinstance(module_param_names, tuple):
            module_param_names = list(module_param_names)
        elif not isinstance(module_param_names, list):
            module_param_names = [module_param_names]
        self.module_param_names = module_param_names
        self.max_norm = max_norm

    @property
    def hook_type(self):
        """Hook type: the hook is registered as a `forward_pre` hook."""
        return 'forward_pre'

    def hook(self, module, _input):
        """Rescale the configured parameters of `module` in place.

        For each configured name, the parameter's norm (as returned by
        `Tensor.norm()` with default arguments) is compared with
        `max_norm`; if it is larger, the parameter is multiplied in place
        by `max_norm / (norm + 1e-6)`.

        Args:
            module (torch.nn.Module): the module about to run forward.
            _input: the forward inputs; unused.

        Raises:
            AssertionError: if a configured name is not a parameter of
                `module`.
        """
        for name in self.module_param_names:
            assert name in module._parameters, f'{name} is not a parameter' \
                f' of the module {type(module)}'
            param = module._parameters[name]

            # The rescale must not be tracked by autograd.
            with torch.no_grad():
                m = param.norm().item()
                if m > self.max_norm:
                    # The small epsilon keeps the divisor strictly positive.
                    param.mul_(self.max_norm / (m + 1e-6))