import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ZeroNeuron(nn.Module):
    """Input-dependent hard-concrete gate that produces a (nearly) binary mask.

    Adapted from:
    https://github.com/asappresearch/flop/blob/master/flop/hardconcrete.py
    We replace ``self.log_alpha = nn.Parameter(...)`` with something
    input-dependent: ``self.log_alpha = nn.Linear(...)``.

    >>> import torch
    >>> x = torch.rand(12, 100)
    >>> module = ZeroNeuron(in_features=100, out_features=100)
    >>> mask = module(x)
    >>> norm = module.l0_norm(x)
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 init_mean: float = 0.5,
                 init_std: float = 0.01,
                 temperature: float = 1.0,
                 stretch: float = 0.1,
                 eps: float = 1e-6) -> None:
        """Initialize the ZeroNeuron (hard-concrete) module.

        Parameters
        ----------
        in_features : int
            The features of the input X.
        out_features : int
            The dimension of the sparsity (should be 1 if you want sparsity
            to be applied on the penultimate dimension of X).
        init_mean : float, optional
            Initialization value for hard concrete parameter, by default 0.5.
        init_std : float, optional
            Used to initialize the hard concrete parameters, by default 0.01.
        temperature : float, optional
            Temperature used to control the sharpness of the distribution,
            by default 1.0.
        stretch : float, optional
            Stretch the sampled value from [0, 1] to the interval
            [-stretch, 1 + stretch], by default 0.1.
        eps : float, optional
            Numerical-stability constant used to clamp uniform noise away
            from 0 and 1, by default 1e-6.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Sampled gates are stretched from [0, 1] to [limit_l, limit_r] and
        # clamped back, which lets the mask take exact 0 / 1 values.
        self.limit_l = -stretch
        self.limit_r = 1.0 + stretch
        # We use a low-rank structure (in_features -> 1 -> out_features)
        # to reduce the computation cost.
        if self.out_features > 1:
            self.log_alpha = nn.Sequential(
                nn.Linear(in_features, 1, bias=False),
                nn.Linear(1, out_features, bias=False),
            )
        else:
            self.log_alpha = nn.Linear(in_features, 1, bias=False)
        self.beta = temperature
        self.init_mean = init_mean
        self.init_std = init_std
        # Constant offset applied to the logits when computing the
        # probability of a gate being non-zero (see l0_norm).
        self.bias = -self.beta * math.log(-self.limit_l / self.limit_r)
        self.eps = eps
        # Deterministic mask computed in eval mode; set by forward().
        # Initialized here so reading it before an eval pass is safe.
        self.compiled_mask = None
        self.log_alpha.apply(self.reset_parameters)

    @torch.no_grad()
    def reset_parameters(self, module):
        """Reset the parameters of this module.

        Applied to every submodule of ``self.log_alpha``; only ``nn.Linear``
        weights are re-initialized, centered so the initial keep-probability
        matches ``init_mean``.
        """
        mean = math.log(1 - self.init_mean) - math.log(self.init_mean)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean, self.init_std)

    def l0_norm(self, x: torch.Tensor, log_alpha=None) -> torch.Tensor:
        """Compute the expected L0 norm of this mask.

        Parameters
        ----------
        x : torch.Tensor
            Input from which the gate logits are computed.
        log_alpha : torch.Tensor, optional
            Precomputed gate logits; when provided, ``x`` is not used.

        Returns
        -------
        torch.Tensor
            The expected L0 norm (mean probability of a gate being non-zero).
        """
        log_alpha = self.log_alpha(x).squeeze(-1) if log_alpha is None else log_alpha
        return (log_alpha + self.bias).sigmoid().mean()

    def forward(self, x: torch.Tensor, dim=None) -> torch.Tensor:  # type: ignore
        """Sample a hard-concrete mask.

        In training mode the mask is sampled stochastically via the
        reparameterization trick; in eval mode a deterministic mask is
        compiled by pruning the gates with the smallest keep-probability.

        Parameters
        ----------
        x : torch.Tensor
            Input from which the gate logits are computed.
        dim : int, optional
            Eval mode only: when given, the exact number of entries to zero
            out; otherwise the count is derived from the logits.

        Returns
        -------
        torch.Tensor
            The sampled (training) or compiled (eval) mask, with values
            in [0, 1].
        """
        log_alpha = self.log_alpha(x).squeeze(-1)
        if self.training:
            # Sample mask dynamically: u ~ Uniform(0, 1), clamped away from
            # the boundaries so the logit transform below stays finite.
            u = torch.rand_like(log_alpha).clamp(self.eps, 1 - self.eps)
            s = torch.sigmoid((torch.log(u / (1 - u)) + log_alpha) / self.beta)
            # Stretch to [limit_l, limit_r], then clamp to [0, 1] so the
            # mask can reach exact zeros and ones.
            s = s * (self.limit_r - self.limit_l) + self.limit_l
            mask = s.clamp(min=0., max=1.)
        else:
            # TODO: use this approach when dim is specified, otherwise use
            # per-sample / per-token sparsity
            if dim is not None:
                expected_num_zeros = dim
            else:
                # Get expected sparsity.
                sparsity_axis = self.out_features if self.out_features != 1 else x.shape[-1]  # b, s
                # NOTE(review): subtracting a *mean* (a scalar in [0, 1])
                # from the axis size prunes nearly every entry; the upstream
                # flop implementation subtracts the *sum* of keep
                # probabilities. Confirm which behavior is intended before
                # relying on this path.
                expected_num_zeros = sparsity_axis - (log_alpha + self.bias).sigmoid().mean().item()
            num_zeros = round(expected_num_zeros)
            # Approximate expected value of each mask variable z;
            # we use an empirically validated magic number 0.8.
            soft_mask = torch.sigmoid(log_alpha / self.beta * 0.8)
            # Prune the smallest keep-probabilities to exactly 0.
            _, indices = torch.topk(soft_mask, k=num_zeros, largest=False)
            soft_mask[..., indices] = 0.
            self.compiled_mask = soft_mask
            mask = self.compiled_mask
        return mask

    def extra_repr(self) -> str:
        """Summary string used by ``__repr__`` (nn.Module convention)."""
        return f"in_features={self.in_features}, out_features={self.out_features}"

    # Backward-compatible alias for the original (misspelled) method name.
    extre_repr = extra_repr

    def __repr__(self) -> str:
        return "{}({})".format(self.__class__.__name__, self.extra_repr())