import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class ZeroNeuron(nn.Module):
    """Input-dependent hard-concrete gate that samples a (soft-)binary mask.

    Adapted from:
    https://github.com/asappresearch/flop/blob/master/flop/hardconcrete.py
    We replace ``self.log_alpha = nn.Parameter(...)`` with something
    input-dependent: ``self.log_alpha = nn.Linear(...)``.

    >>> import torch
    >>> x = torch.rand(12, 100)
    >>> module = ZeroNeuron(in_features=100, out_features=100)
    >>> mask = module(x)
    >>> norm = module.l0_norm(x)
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 init_mean: float = 0.5,
                 init_std: float = 0.01,
                 temperature: float = 1.0,
                 stretch: float = 0.1,
                 eps: float = 1e-6) -> None:
        """Initialize the ZeroNeuron module.

        Parameters
        ----------
        in_features : int
            The features of the input X.
        out_features : int
            The dimension of the sparsity (should be 1 if you want sparsity
            to be applied on the penultimate dimension of X).
        init_mean : float, optional
            Initialization value for the hard concrete parameter,
            by default 0.5.
        init_std : float, optional
            Used to initialize the hard concrete parameters,
            by default 0.01.
        temperature : float, optional
            Temperature used to control the sharpness of the
            distribution, by default 1.0.
        stretch : float, optional
            Stretch the sampled value from [0, 1] to the interval
            [-stretch, 1 + stretch], by default 0.1.
        eps : float, optional
            Numerical-stability clamp for the uniform noise, by default 1e-6.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.limit_l = -stretch
        self.limit_r = 1.0 + stretch
        # Low-rank factorization (in -> 1 -> out) keeps the gate cheap
        # when out_features > 1.
        if self.out_features > 1:
            self.log_alpha = nn.Sequential(
                nn.Linear(in_features, 1, bias=False),
                nn.Linear(1, out_features, bias=False),
            )
        else:
            self.log_alpha = nn.Linear(in_features, 1, bias=False)
        self.beta = temperature
        self.init_mean = init_mean
        self.init_std = init_std
        # Constant offset so that the gate-open probability is
        # sigmoid(log_alpha - beta * log(-limit_l / limit_r)).
        self.bias = -self.beta * math.log(-self.limit_l / self.limit_r)
        self.eps = eps
        self.log_alpha.apply(self.reset_parameters)

    @torch.no_grad()
    def reset_parameters(self, module):
        """Reset the parameters of this module (applied to each inner Linear)."""
        # Chosen so that sigmoid(log_alpha) starts around init_mean.
        mean = math.log(1 - self.init_mean) - math.log(self.init_mean)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean, self.init_std)

    def l0_norm(self, x: torch.Tensor, log_alpha=None) -> torch.Tensor:
        """Compute the expected L0 norm (mean gate-open probability) of this mask.

        Parameters
        ----------
        x : torch.Tensor
            Input used to produce ``log_alpha`` when it is not supplied.
        log_alpha : torch.Tensor, optional
            Precomputed gate logits; recomputed from ``x`` when ``None``.

        Returns
        -------
        torch.Tensor
            Scalar tensor: the mean probability that a gate is non-zero.
        """
        log_alpha = self.log_alpha(x).squeeze(-1) if log_alpha is None else log_alpha
        return (log_alpha + self.bias).sigmoid().mean()

    def forward(self, x: torch.Tensor, dim=None) -> torch.Tensor:  # type: ignore
        """Sample a hard concrete mask.

        Parameters
        ----------
        x : torch.Tensor
            Input from which the gate logits are computed.
        dim : int, optional
            If given at eval time, prune exactly ``dim`` entries instead of
            the expected number of zeros.

        Returns
        -------
        torch.Tensor
            The sampled (approximately binary) mask, values in [0, 1].
        """
        log_alpha = self.log_alpha(x).squeeze(-1)
        if self.training:
            # Reparameterized sample of the stretched hard-concrete gate.
            u = torch.rand_like(log_alpha).clamp(self.eps, 1 - self.eps)
            # torch.sigmoid replaces the deprecated F.sigmoid (same result).
            s = torch.sigmoid((torch.log(u / (1 - u)) + log_alpha) / self.beta)
            s = s * (self.limit_r - self.limit_l) + self.limit_l
            mask = s.clamp(min=0., max=1.)
        else:
            # TODO: use this approach when dim is specified, otherwise use
            # per-sample / per-token sparsity.
            if dim is not None:
                expected_num_zeros = dim
            else:
                # Size of the sparsified axis.
                # NOTE(review): for out_features == 1 the mask gates the
                # penultimate dimension of x, so x.shape[-1] may not be the
                # intended axis size here — confirm against callers.
                sparsity_axis = self.out_features if self.out_features != 1 else x.shape[-1]
                # Expected number of open gates per sample, averaged over the
                # batch. BUGFIX: the original subtracted .mean() — a value in
                # (0, 1) — which made num_zeros ~= sparsity_axis and pruned
                # nearly the whole axis; the reference flop implementation
                # sums the open probabilities instead.
                expected_nonzeros = (log_alpha + self.bias).sigmoid().sum(-1).mean().item()
                expected_num_zeros = sparsity_axis - expected_nonzeros
            # Clamp so topk's k is always valid for the last dimension.
            num_zeros = max(0, min(round(expected_num_zeros), log_alpha.shape[-1]))
            # Approximate expected value of each mask variable z;
            # we use an empirically validated magic number 0.8.
            soft_mask = torch.sigmoid(log_alpha / self.beta * 0.8)
            # Prune the smallest values to exactly 0.
            _, indices = torch.topk(soft_mask, k=num_zeros, largest=False)
            soft_mask[..., indices] = 0.
            self.compiled_mask = soft_mask
            mask = self.compiled_mask
        return mask

    def extra_repr(self) -> str:
        """Standard ``nn.Module`` extra-repr string (fixes the ``extre_repr`` typo)."""
        return f"in_features={self.in_features}, out_features={self.out_features}"

    # Backward-compatible alias for the original misspelled method name.
    extre_repr = extra_repr

    def __repr__(self) -> str:
        return "{}({})".format(self.__class__.__name__, self.extra_repr())