| | from typing import Any, Dict, Optional |
| |
|
| | from torch import nn |
| |
|
| | # Public API of this module: the names exported by a star-import. |
| | __all__ = [ |
| | "module_to_fqn", |
| | "fqn_to_module", |
| | "get_arg_info_from_tensor_fqn", |
| | "FakeSparsity", |
| | ] |
| |
|
| |
|
def module_to_fqn(model: nn.Module, module: nn.Module, prefix: str = "") -> Optional[str]:
    """
    Find the fully-qualified name (fqn) of ``module`` relative to ``model``.

    The children of ``model`` are searched depth-first, in ``named_children()``
    order, and the first module that is identical (``is``) to ``module`` wins.

    Args:
        model: root module at which the search starts.
        module: module object to look for.
        prefix: string placed in front of the returned name when ``module`` is
            found below ``model``. Recursive calls pass ``"."`` so that nested
            names come out dot-separated.

    Returns:
        ``""`` when ``module`` is ``model`` itself (``prefix`` is not applied),
        the prefixed dotted name when ``module`` is a descendant of ``model``,
        and ``None`` when it is not found.
    """
    if model is module:
        return ""

    for child_name, child in model.named_children():
        # A hit below ``child`` already carries its own leading "." separator.
        tail = module_to_fqn(child, module, prefix=".")
        if tail is None:
            continue
        return prefix + child_name + tail

    return None
| |
|
| |
|
def fqn_to_module(model: Optional[nn.Module], path: str) -> Optional[nn.Module]:
    """
    Given an fqn, returns the corresponding module or tensor, or None if the fqn
    given by `path` doesn't correspond to anything. Similar to
    model.get_submodule(path) but works for tensors.

    Args:
        model: object at which the attribute lookup starts (may be None).
        path: dot-separated attribute path; the empty string refers to
            `model` itself.

    Returns:
        The object reached by following `path` one attribute at a time with
        `getattr`, or None as soon as any step cannot be resolved.
    """
    if path == "":
        return model

    current = model
    for name in path.split("."):
        # Stop at the first missing link. Continuing the lookup on None could
        # resolve attributes of NoneType itself (e.g. "__class__", "__doc__")
        # and return something for a path that does not exist.
        if current is None:
            return None
        current = getattr(current, name, None)
    return current
| |
|
| |
|
def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> Dict[str, Any]:
    """
    Split ``tensor_fqn`` into the fqn of its owning module and the tensor name,
    and look the owning module up in ``model``.

    Args:
        model: root module passed to ``fqn_to_module`` for the lookup.
        tensor_fqn: dot-separated name of a tensor, e.g. ``"block.linear.weight"``.

    Returns:
        A dict with the keys ``"module_fqn"`` (everything before the last
        ``"."``, or ``""`` when there is no ``"."``), ``"module"`` (the result
        of ``fqn_to_module(model, module_fqn)``), ``"tensor_name"`` (the part
        after the last ``"."``) and ``"tensor_fqn"`` (the input, unchanged).
    """
    # rpartition splits on the LAST "."; without any "." the whole string ends
    # up as the tensor name and the module fqn is "".
    owner_fqn, _, leaf_name = tensor_fqn.rpartition(".")

    return {
        "module_fqn": owner_fqn,
        "module": fqn_to_module(model, owner_fqn),
        "tensor_name": leaf_name,
        "tensor_fqn": tensor_fqn,
    }
| |
|
| |
|
| | |
class FakeSparsity(nn.Module):
    r"""Parametrization that masks a tensor by elementwise multiplication.

    Meant to be attached to the 'weight' or any other parameter that requires
    a mask applied to it.

    Note::

        After the mask has been passed in, keep using the very same object.
        Its contents may be updated in place, but the mask reference itself
        must not be replaced.
    """

    def __init__(self, mask):
        super().__init__()
        # Stored as a buffer named "mask", reachable afterwards as ``self.mask``.
        self.register_buffer("mask", mask)

    def forward(self, x):
        mask = self.mask
        # The mask must match the input's shape exactly.
        assert mask.shape == x.shape
        return mask * x

    def state_dict(self, *args, **kwargs):
        # Always report an empty state, whatever arguments are given, so the
        # mask buffer never appears in this module's own state_dict
        # (presumably so masks are not saved alongside the parametrized
        # module - confirm with callers).
        return dict()
| |
|