ll / zero_neuron.py
kaamd's picture
Upload folder using huggingface_hub
6a48e45 verified
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class ZeroNeuron(nn.Module):
    """Input-dependent hard-concrete gate that produces sparse masks.

    Adapted from:
    https://github.com/asappresearch/flop/blob/master/flop/hardconcrete.py

    The upstream ``HardConcrete`` stores ``log_alpha`` as a free
    ``nn.Parameter``. Here it is replaced with something input-dependent
    (``nn.Linear`` layers applied to ``x``), so the mask is computed from
    the input itself.

    >>> import torch
    >>> x = torch.rand(12, 100)
    >>> module = ZeroNeuron(in_features=100, out_features=100)
    >>> mask = module(x)
    >>> norm = module.l0_norm(x)
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 init_mean: float = 0.5,
                 init_std: float = 0.01,
                 temperature: float = 1.0,
                 stretch: float = 0.1,
                 eps: float = 1e-6) -> None:
        """Initialize the ZeroNeuron module.

        Parameters
        ----------
        in_features : int
            The features of the input X.
        out_features : int
            The dimension of the sparsity (should be 1 if you want sparsity
            to be applied on the penultimate dimension of X).
        init_mean : float, optional
            Initialization value for hard concrete parameter,
            by default 0.5.
        init_std : float, optional
            Used to initialize the hard concrete parameters,
            by default 0.01.
        temperature : float, optional
            Temperature used to control the sharpness of the
            distribution, by default 1.0.
        stretch : float, optional
            Stretch the sampled value from [0, 1] to the interval
            [-stretch, 1 + stretch], by default 0.1.
        eps : float, optional
            Uniform noise is clamped to ``[eps, 1 - eps]`` so that its
            logit stays finite, by default 1e-6.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.limit_l = -stretch
        self.limit_r = 1.0 + stretch
        # We use a low-rank structure (in_features -> 1 -> out_features)
        # to reduce the computation cost.
        if self.out_features > 1:
            self.log_alpha = nn.Sequential(
                nn.Linear(in_features, 1, bias=False),
                nn.Linear(1, out_features, bias=False),
            )
        else:
            self.log_alpha = nn.Linear(in_features, 1, bias=False)
        self.beta = temperature
        self.init_mean = init_mean
        self.init_std = init_std
        self.bias = -self.beta * math.log(-self.limit_l / self.limit_r)
        self.eps = eps
        # ``apply`` visits every sub-module; only nn.Linear ones are touched.
        self.log_alpha.apply(self.reset_parameters)

    @torch.no_grad()
    def reset_parameters(self, module):
        """Re-initialize ``module`` if it is an ``nn.Linear``.

        Weights are drawn from a normal distribution with mean
        ``log(1 - init_mean) - log(init_mean)`` and std ``init_std``.
        Modules of any other type are left untouched.
        """
        mean = math.log(1 - self.init_mean) - math.log(self.init_mean)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean, self.init_std)

    def l0_norm(self, x: torch.Tensor, log_alpha=None) -> torch.Tensor:
        """Compute the expected (mean-normalized) L0 norm of the mask.

        Parameters
        ----------
        x : torch.Tensor
            Input used to compute ``log_alpha`` when it is not supplied.
        log_alpha : torch.Tensor, optional
            Pre-computed gate logits; when given, ``x`` is not used.

        Returns
        -------
        torch.Tensor
            Scalar tensor: the mean over all gates of
            ``sigmoid(log_alpha + bias)``. This is a fraction in (0, 1),
            not a count.
        """
        if log_alpha is None:
            log_alpha = self.log_alpha(x).squeeze(-1)
        return (log_alpha + self.bias).sigmoid().mean()

    def forward(self, x: torch.Tensor, dim=None) -> torch.Tensor:  # type: ignore
        """Sample a hard-concrete mask.

        Parameters
        ----------
        x : torch.Tensor
            Input whose last dimension is ``in_features``.
        dim : int, optional
            Evaluation mode only: number of entries to zero out along the
            last dimension of the mask. When ``None`` the number is derived
            from the expected density of the mask.

        Returns
        -------
        torch.Tensor
            In training mode, a stochastic mask with values in [0, 1].
            In evaluation mode, a deterministic soft mask whose smallest
            entries along the last dimension are set to exactly 0. That
            mask is also stored in ``self.compiled_mask``.
        """
        log_alpha = self.log_alpha(x).squeeze(-1)

        if self.training:
            # Sample mask dynamically.
            u = torch.rand_like(log_alpha).clamp(self.eps, 1 - self.eps)
            s = torch.sigmoid((torch.log(u / (1 - u)) + log_alpha) / self.beta)
            s = s * (self.limit_r - self.limit_l) + self.limit_l
            return s.clamp(min=0., max=1.)

        # The mask is pruned along its own last dimension, so that is the
        # size that bounds how many entries can be zeroed.
        axis_size = log_alpha.shape[-1] if log_alpha.dim() > 0 else 1

        # TODO: use this approach when dim is specified, otherwise use
        # per-sample / per-token sparsity.
        if dim is not None:
            expected_num_zeros = dim
        else:
            # Expected fraction of open gates; scale by the axis size to
            # turn the complementary fraction into a count of zeros.
            density = (log_alpha + self.bias).sigmoid().mean().item()
            expected_num_zeros = axis_size * (1.0 - density)
        # topk requires 0 <= k <= axis_size.
        num_zeros = max(0, min(int(round(expected_num_zeros)), axis_size))

        # Approximate expected value of each mask variable z;
        # we use an empirically validated magic number 0.8.
        soft_mask = torch.sigmoid(log_alpha / self.beta * 0.8)

        if num_zeros > 0:
            # Prune the smallest values of every row independently.
            # scatter along the last dim keeps each row's indices to that
            # row; plain fancy indexing would zero the union of all rows'
            # indices in every row.
            _, indices = torch.topk(soft_mask, k=num_zeros, largest=False)
            soft_mask = soft_mask.scatter(-1, indices, 0.)

        self.compiled_mask = soft_mask
        return self.compiled_mask

    def extra_repr(self) -> str:
        """Return the argument summary used in the module's repr."""
        return f"in_features={self.in_features}, out_features={self.out_features}"

    def extre_repr(self) -> str:
        """Misspelled alias of :meth:`extra_repr`, kept for existing callers."""
        return self.extra_repr()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.extra_repr()})"