File size: 4,912 Bytes
6a48e45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ZeroNeuron(nn.Module):
    """Input-dependent hard-concrete gate that samples (near-)binary sparsity masks.

    Adapted from: https://github.com/asappresearch/flop/blob/master/flop/hardconcrete.py
    We replace ``self.log_alpha = nn.Parameter(...)`` with something
    input-dependent: ``self.log_alpha = nn.Linear(...)``.

    >>> import torch
    >>> x = torch.rand(12, 100)
    >>> module = ZeroNeuron(in_features=100, out_features=100)
    >>> mask = module(x)
    >>> norm = module.l0_norm(x)

    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 init_mean: float = 0.5,
                 init_std: float = 0.01,
                 temperature: float = 1.0,
                 stretch: float = 0.1,
                 eps: float = 1e-6) -> None:
        """Initialize the ZeroNeuron module.

        Parameters
        ----------
        in_features : int
            The features of the input X.
        out_features : int
            The dimension of the sparsity (should be 1 if you want sparsity
            to be applied on the penultimate dimension of X).
        init_mean : float, optional
            Initialization value for hard concrete parameter,
            by default 0.5.
        init_std : float, optional
            Used to initialize the hard concrete parameters,
            by default 0.01.
        temperature : float, optional
            Temperature used to control the sharpness of the
            distribution, by default 1.0.
        stretch : float, optional
            Stretch the sampled value from [0, 1] to the interval
            [-stretch, 1 + stretch], by default 0.1.
        eps : float, optional
            Numerical floor/ceiling used when clamping the uniform noise
            before the logistic transform, by default 1e-6.

        """
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.limit_l = -stretch
        self.limit_r = 1.0 + stretch
        # We use a low-rank structure (in_features -> 1 -> out_features)
        # to reduce the computation cost when out_features > 1.
        if self.out_features > 1:
            self.log_alpha = nn.Sequential(nn.Linear(in_features, 1, bias=False), nn.Linear(1, out_features, bias=False))
        else:
            self.log_alpha = nn.Linear(in_features, 1, bias=False)

        self.beta = temperature
        self.init_mean = init_mean
        self.init_std = init_std
        # Constant offset so that sigmoid(log_alpha + bias) is the probability
        # that the stretched hard-concrete sample is non-zero.
        self.bias = -self.beta * math.log(-self.limit_l / self.limit_r)

        self.eps = eps
        self.log_alpha.apply(self.reset_parameters)

    @torch.no_grad()
    def reset_parameters(self, module):
        """Reset the parameters of this module.

        Initializes every ``nn.Linear`` weight from a normal distribution
        centered at the logit of ``1 - init_mean``.
        """
        mean = math.log(1 - self.init_mean) - math.log(self.init_mean)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean, self.init_std)

    def l0_norm(self, x: torch.Tensor, log_alpha=None) -> torch.Tensor:
        """Compute the expected L0 norm of this mask.

        Parameters
        ----------
        x : torch.Tensor
            Input used to produce the gate logits. Ignored when
            ``log_alpha`` is given.
        log_alpha : torch.Tensor, optional
            Precomputed gate logits; when ``None`` they are computed from
            ``x``, by default None.

        Returns
        -------
        torch.Tensor
            The expected L0 norm (mean keep-probability over all gates).

        """
        log_alpha = self.log_alpha(x).squeeze(-1) if log_alpha is None else log_alpha
        return (log_alpha + self.bias).sigmoid().mean()

    def forward(self, x: torch.Tensor, dim=None) -> torch.Tensor:  # type: ignore
        """Sample a hard-concrete mask.

        Parameters
        ----------
        x : torch.Tensor
            Input features the gate logits are computed from.
        dim : int, optional
            When given (eval mode only), the exact number of entries to
            zero out instead of estimating it from the keep-probabilities.

        Returns
        -------
        torch.Tensor
            The sampled (training) or deterministic (eval) mask in [0, 1].

        """
        log_alpha = self.log_alpha(x).squeeze(-1)

        if self.training:
            # Sample the mask stochastically: logistic noise pushed through a
            # sigmoid, stretched to [limit_l, limit_r], then hard-clipped.
            u = torch.rand_like(log_alpha).clamp(self.eps, 1 - self.eps)
            s = torch.sigmoid((torch.log(u / (1 - u)) + log_alpha) / self.beta)
            s = s * (self.limit_r - self.limit_l) + self.limit_l
            mask = s.clamp(min=0., max=1.)

        else:
            # TODO: use this approach when dim is specified, otherwise use
            # per-sample / per-token sparsity.
            if dim is not None:
                expected_num_zeros = dim
            else:
                # Get expected sparsity.
                sparsity_axis = self.out_features if self.out_features != 1 else x.shape[-1]
                # NOTE(review): flop's reference subtracts the *sum* of
                # keep-probabilities; the mean here is in [0, 1], so
                # expected_num_zeros ends up ~= sparsity_axis (almost
                # everything pruned) — confirm this is intended.
                expected_num_zeros = sparsity_axis - (log_alpha + self.bias).sigmoid().mean().item()
            # Clamp into a valid topk range so a noisy estimate (or an
            # out-of-range `dim`) cannot crash torch.topk.
            num_zeros = max(0, min(round(expected_num_zeros), log_alpha.shape[-1]))
            # Approximate expected value of each mask variable z;
            # we use an empirically validated magic number 0.8.
            soft_mask = torch.sigmoid(log_alpha / self.beta * 0.8)
            # Prune the k smallest values *per row*: topk indices are
            # per-row, so scatter_ (not advanced indexing, which would zero
            # the union of columns across the batch) sets them to 0.
            _, indices = torch.topk(soft_mask, k=num_zeros, largest=False)
            soft_mask = soft_mask.scatter(-1, indices, 0.)
            self.compiled_mask = soft_mask
            mask = self.compiled_mask

        return mask

    def extra_repr(self) -> str:
        """Standard nn.Module extra_repr (fixes the `extre_repr` typo)."""
        return f"in_features={self.in_features}, out_features={self.out_features}"

    def extre_repr(self) -> str:
        """Deprecated misspelled alias kept for backward compatibility."""
        return self.extra_repr()

    def __repr__(self) -> str:
        return "{}({})".format(self.__class__.__name__, self.extra_repr())