| | import math |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| |
|
class ZeroNeuron(nn.Module):
    """Input-dependent hard-concrete gate producing an (almost) binary mask.

    Adapted from:
    https://github.com/asappresearch/flop/blob/master/flop/hardconcrete.py
    The original's static ``self.log_alpha = nn.Parameter(...)`` is replaced
    with something input-dependent: ``self.log_alpha = nn.Linear(...)``.

    >>> import torch
    >>> x = torch.rand(12, 100)
    >>> module = ZeroNeuron(in_features=100, out_features=100)
    >>> mask = module(x)
    >>> norm = module.l0_norm(x)

    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 init_mean: float = 0.5,
                 init_std: float = 0.01,
                 temperature: float = 1.0,
                 stretch: float = 0.1,
                 eps: float = 1e-6) -> None:
        """Initialize the hard-concrete gating module.

        Parameters
        ----------
        in_features : int
            The features of the input X.
        out_features : int
            The dimension of the sparsity (should be 1 if you want sparsity
            to be applied on the penultimate dimension of X).
        init_mean : float, optional
            Initialization value for hard concrete parameter,
            by default 0.5.
        init_std : float, optional
            Used to initialize the hard concrete parameters,
            by default 0.01.
        temperature : float, optional
            Temperature used to control the sharpness of the
            distribution, by default 1.0.
        stretch : float, optional
            Stretch the sampled value from [0, 1] to the interval
            [-stretch, 1 + stretch], by default 0.1.
        eps : float, optional
            Clamp margin keeping the uniform sample away from {0, 1} so
            the logit transform stays finite, by default 1e-6.

        """
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        # Support of the stretched concrete distribution: [-stretch, 1 + stretch].
        self.limit_l = -stretch
        self.limit_r = 1.0 + stretch

        # Input-dependent log_alpha. For out_features > 1 a rank-1
        # bottleneck (Linear -> 1 -> out_features) keeps the parameter
        # count at in_features + out_features instead of their product.
        if self.out_features > 1:
            self.log_alpha = nn.Sequential(
                nn.Linear(in_features, 1, bias=False),
                nn.Linear(1, out_features, bias=False),
            )
        else:
            self.log_alpha = nn.Linear(in_features, 1, bias=False)

        self.beta = temperature
        self.init_mean = init_mean
        self.init_std = init_std
        # Constant shift so that sigmoid(log_alpha + bias) is the
        # probability that a gate stays non-zero after clamping.
        self.bias = -self.beta * math.log(-self.limit_l / self.limit_r)

        self.eps = eps
        self.log_alpha.apply(self.reset_parameters)

    @torch.no_grad()
    def reset_parameters(self, module: nn.Module) -> None:
        """Reset the parameters of this module.

        Linear weights are drawn around logit(1 - init_mean) so the
        initial keep-probability is approximately ``init_mean``.
        """
        mean = math.log(1 - self.init_mean) - math.log(self.init_mean)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean, self.init_std)

    def l0_norm(self, x: torch.Tensor, log_alpha: torch.Tensor = None) -> torch.Tensor:
        """Compute the expected L0 norm of this mask.

        Parameters
        ----------
        x : torch.Tensor
            Input used to compute ``log_alpha`` when it is not supplied.
        log_alpha : torch.Tensor, optional
            Precomputed gate logits; recomputed from ``x`` when None.

        Returns
        -------
        torch.Tensor
            The expected L0 norm (scalar mean keep-probability).

        """
        log_alpha = self.log_alpha(x).squeeze(-1) if log_alpha is None else log_alpha
        return (log_alpha + self.bias).sigmoid().mean()

    def forward(self, x: torch.Tensor, dim: int = None) -> torch.Tensor:
        """Sample a hard-concrete mask.

        In training mode a differentiable stretched-and-clamped concrete
        sample is drawn; in eval mode a deterministic mask is compiled by
        zeroing the smallest soft gates.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``(..., in_features)``.
        dim : int, optional
            Eval-only override for the number of gates to zero out.

        Returns
        -------
        torch.Tensor
            The sampled binary mask (values in [0, 1]).

        """
        log_alpha = self.log_alpha(x).squeeze(-1)

        if self.training:
            # Reparameterized concrete sample: uniform noise -> logit ->
            # tempered sigmoid, then stretch and hard-clamp to [0, 1].
            u = torch.rand_like(log_alpha).clamp(self.eps, 1 - self.eps)
            s = torch.sigmoid((torch.log(u / (1 - u)) + log_alpha) / self.beta)
            s = s * (self.limit_r - self.limit_l) + self.limit_l
            mask = s.clamp(min=0., max=1.)
        else:
            if dim is not None:
                expected_num_zeros = dim
            else:
                sparsity_axis = self.out_features if self.out_features != 1 else x.shape[-1]
                # NOTE(review): `.mean()` lies in [0, 1], so nearly all of
                # `sparsity_axis` gates get zeroed at eval time; a `.sum()`
                # over the sparsity axis looks intended — confirm against
                # the training/pruning pipeline before changing.
                expected_num_zeros = sparsity_axis - (log_alpha + self.bias).sigmoid().mean().item()
            num_zeros = round(expected_num_zeros)

            # Deterministic soft gate; the 0.8 factor sharpens the sigmoid.
            soft_mask = torch.sigmoid(log_alpha / self.beta * 0.8)
            # Zero out the `num_zeros` smallest gates along the last dim.
            _, indices = torch.topk(soft_mask, k=num_zeros, largest=False)
            soft_mask[..., indices] = 0.
            self.compiled_mask = soft_mask
            mask = self.compiled_mask

        return mask

    def extra_repr(self) -> str:
        """Summary string used by ``nn.Module`` representations."""
        return f"in_features={self.in_features}, out_features={self.out_features}"

    # Backward-compatible alias for the original (misspelled) method name.
    extre_repr = extra_repr

    def __repr__(self) -> str:
        return "{}({})".format(self.__class__.__name__, self.extra_repr())