from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import ClassVar, Literal, Protocol, Type, runtime_checkable

import torch
from torch import nn
|
|
|
|
| class Parametrization(nn.Module, ABC): |
| """ |
| Abstract base class for parametrizations. |
| A parametrization can be injected into any torch module of type `base_class` by `parametrize_module`. |
| A parametrized module will follow the `ParametrizedModule` interface. |
| |
| This will overload the weight, bias, and forward of the module so that they play together with |
| the parametrization. The external behavior of the parametrized module remains unchanged, for instance, |
| a parametrized `Linear` module will still work as expected. |
| |
| Attributes: |
| base_class: The base class of the module that can be parametrized. |
| initialized: A flag that indicates whether the parametrization has been initialized. |
| """ |
|
|
| initialized: bool = False |
| base_class: ClassVar[Type[nn.Module]] |
|
|
| def initialize(self, base_module: "Parametrization.base_class") -> None: |
| self._initialize(base_module) |
| self.initialized = True |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| """ |
| Compute the forward pass of the parametrization. |
| This is particularly important when a standard forward pass based on `weight` would be inefficient. |
| """ |
| assert self.initialized |
| x = self._forward(x) |
| return x |
|
|
| @property |
| def weight(self) -> torch.Tensor: |
| """Compute the weight tensor of the parametrization.""" |
| return self._weight() |
|
|
| @property |
| def bias(self) -> torch.Tensor | None: |
| """Compute the bias tensor of the parametrization.""" |
| return self._bias() |
|
|
| @abstractmethod |
| def _forward(self, x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError |
|
|
| @abstractmethod |
| def _initialize(self, base_module: "Parametrization.base_class") -> None: |
| """ |
| Initialize the parametrization based on a given base module. |
| This method should build the internal representation the module's weight and bias, |
| registering all required buffers and parameters in `self`. |
| """ |
| raise NotImplementedError |
|
|
| @abstractmethod |
| def _weight(self) -> torch.Tensor: |
| raise NotImplementedError |
|
|
| @abstractmethod |
| def _bias(self) -> torch.Tensor | None: |
| raise NotImplementedError |
|
|
| @abstractmethod |
| def get_target_params(self) -> dict[str, torch.nn.Parameter]: |
| """ |
| Return the (tunable) target parameters of the parametrization. |
| Here, "target parameters" means that they can be tuned and potentially compressed |
| by `self.reset_target_params(mode="compress")`. |
| Other torch parameters of the module could be tuned as well, but should not returned here. |
| The returned dictionary should be compatible with `self.named_parameters()`. |
| |
| See Also: |
| - `ParametrizedModel.get_target_params` |
| - `ParametrizedModel.compress` |
| """ |
| raise NotImplementedError |
|
|
| @abstractmethod |
| def reset_target_params(self, mode: Literal["full", "nonzero", "compress"] = "full") -> None: |
| """ |
| Reset the target parameters of the parametrization according to a given mode. |
| |
| Args: |
| mode: The reset mode. |
| "full" means reset to original value at initialization. |
| "nonzero" means reset all non-zero values to original value at initialization. |
| "compress" means the all zero values are removed and the the parameters are compressed accordingly. |
| """ |
| raise NotImplementedError |
|
|
| @abstractmethod |
| def get_num_params(self, compressed: bool = False, target_params: dict[str, torch.Tensor] | None = None) -> int: |
| """ |
| Computes the (effective) number of parameters of the parametrization. |
| |
| Args: |
| compressed: Whether to count the number of parameters as if the module was actually compressed. |
| If `False`, the number of parameters is the same as in the original module. |
| target_params: Count the number of parameters as if `target_params` were used instead of |
| `self.get_target_params()`. This "what if" feature is important when pruning |
| a full `ParametrizedModel` to a certain target ratio. |
| """ |
| raise NotImplementedError |
|
|
|
|
@runtime_checkable
class ParametrizedModule(Protocol):
    """
    Interface for a parametrized `nn.Module`.
    It ensures that `weight` and `bias` are forwarded to the `Parametrization` instance.

    Attributes:
        parametrization: The `Parametrization` instance of the module.
        _forward: The original forward function of the module.
        __old_class__: The original class of the module.

    Notes:
        `_forward` and `__old_class__` are used by `parametrize_module` and `unparametrize_module`
        to allow restoring the original behavior of the module.
    """

    parametrization: Parametrization
    # `Callable` (from collections.abc), not the `callable` builtin: the builtin
    # is a function, not a type, so static checkers reject it as an annotation.
    _forward: Callable
    __old_class__: type[nn.Module]

    @property
    def weight(self) -> torch.Tensor:
        """The weight tensor, computed by the parametrization."""
        return self.parametrization.weight

    @property
    def bias(self) -> torch.Tensor | None:
        """The bias tensor, computed by the parametrization (may be None)."""
        return self.parametrization.bias
|
|
|
|
def parametrize_module(module: nn.Module, parametrization: Parametrization) -> nn.Module:
    """
    Parametrize a module using a `Parametrization` instance.

    Args:
        module: The module to be parametrized. Must be an instance of `parametrization.base_class`.
        parametrization: The `Parametrization` instance to be applied to the module.

    Returns: The parametrized module; it also satisfies the `ParametrizedModule` interface.
        (Python typing has no intersection types, so only the `nn.Module` side is annotated.)

    Raises:
        TypeError: If `module` is not an instance of `parametrization.base_class`.

    Notes:
        Adopted from https://stackoverflow.com/a/31075641
    """
    # Validate with a real exception instead of `assert` (stripped under `python -O`).
    if not isinstance(module, parametrization.base_class):
        raise TypeError(
            f"Expected an instance of {parametrization.base_class.__name__}, got {type(module).__name__}."
        )
    # Remember the original class so `unparametrize_module` can restore it.
    module.__old_class__ = module.__class__

    # Register the parametrization as a submodule and let it build its internal
    # representation from the module's current weight and bias.
    module.add_module("parametrization", parametrization)
    module.parametrization.initialize(module)

    # Keep the original bound forward so it can be restored later.
    module._forward = module.forward

    # Drop the original parameters; after the class swap, `weight`/`bias`
    # resolve to the `ParametrizedModule` properties, which delegate to the
    # parametrization.
    del module.weight
    del module.bias
    module.__class__ = type("Parametrized" + module.__class__.__name__, (module.__class__, ParametrizedModule), {})
    # Shadow `forward` on the instance so calls go through the parametrization.
    module.forward = module.parametrization.forward

    return module
|
|
|
|
def unparametrize_module(module: ParametrizedModule) -> nn.Module:
    """
    Revert the parametrization of a module.

    Args:
        module: A module that has been parametrized by `parametrize_module`.

    Returns: The original module.

    Notes:
        Adopted from https://stackoverflow.com/a/31075641
    """

    # Materialize the tensors through the ParametrizedModule properties BEFORE
    # the class swap below removes those properties.
    weight = module.weight
    bias = module.bias

    # Narrows the type for checkers: a parametrized module is still an nn.Module.
    assert isinstance(module, nn.Module)

    # Swap back to (a fresh subclass of) the original class, which drops the
    # ParametrizedModule mixin and its weight/bias properties.
    module.__class__ = type(module.__old_class__.__name__, (module.__old_class__,), {})
    delattr(module, "__old_class__")

    # Re-register the materialized tensors as plain parameters, preserving the
    # original requires_grad flags.
    # NOTE(review): if a parametrization computes `weight` as a non-leaf
    # autograd tensor, nn.Parameter(...) may reject it — confirm that
    # parametrizations return leaf/detached tensors here.
    module.register_parameter("weight", nn.Parameter(weight, weight.requires_grad))
    if bias is not None:
        module.register_parameter("bias", nn.Parameter(bias, bias.requires_grad))
    else:
        module.register_parameter("bias", None)

    # Remove the parametrization submodule and restore the original forward
    # (the bound method saved by `parametrize_module`).
    del module.parametrization
    module.forward = module._forward
    delattr(module, "_forward")

    return module
|
|