File size: 3,957 Bytes
ad5f26a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# mypy: allow-untyped-defs
import warnings

from .base_scheduler import BaseScheduler


__all__ = ["CubicSL"]


def _clamp(x, lo, hi):
    return max(lo, min(hi, x))


class CubicSL(BaseScheduler):
    r"""Sets the sparsity level of each parameter group to the final sl
    plus a given cubic function.

    .. math::

        s_i = s_f + (s_0 - s_f) \cdot \left( 1 - \frac{t - t_0}{n\Delta t} \right)^3

    where :math:`s_i` is the sparsity at epoch :math:`t`, :math:`s_0` is the
    initial sparsity level, :math:`s_f` is the final sparsity level,
    :math:`t_0` is the initial epoch, :math:`\Delta t` is the pruning
    frequency and :math:`n` is the total number of pruning steps.
    The computed value is clamped to the closed range spanned by
    :math:`s_0` and :math:`s_f`, so it stays at :math:`s_0` before
    :math:`t_0` (unless ``initially_zero`` is set) and at :math:`s_f` after
    :math:`t_0 + n\Delta t`.

    Note that in this implementation :math:`\Delta t` and :math:`n` enter the
    computation only through their product, i.e. the length of the ramp.

    The final sparsity level of each group is read from ``self.base_sl``,
    which this class does not set itself (presumably provided by
    ``BaseScheduler`` -- confirm against the base class).

    Args:
        sparsifier (BaseSparsifier): Wrapped sparsifier.
        init_sl (float, list): Initial level of sparsity
        init_t (int, list): Initial step, when pruning starts
        delta_t (int, list): Pruning frequency
        total_t (int, list): Total number of pruning steps
        initially_zero (bool, list): If True, sets the level of sparsity to 0
            before init_t (:math:`t_0`). Otherwise, the sparsity level before
            init_t (:math:`t_0`) is set to init_sl (:math:`s_0`)
        last_epoch (int): The index of last epoch. Default: -1.
            Forwarded to the base scheduler.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``. Forwarded to the base scheduler.
    """

    def __init__(
        self,
        sparsifier,
        init_sl=0.0,
        init_t=0,
        delta_t=10,
        total_t=100,
        initially_zero=False,
        last_epoch=-1,
        verbose=False,
    ):
        # Assigned before the helper calls below; presumably
        # `_make_sure_a_list` needs the sparsifier -- keep this order.
        self.sparsifier = sparsifier

        # Each setting is normalized by the inherited `_make_sure_a_list`
        # helper so that `get_sl` can zip them together per group.
        self.init_sl = self._make_sure_a_list(init_sl)
        self.init_t = self._make_sure_a_list(init_t)
        self.delta_t = self._make_sure_a_list(delta_t)
        self.total_t = self._make_sure_a_list(total_t)

        self.initially_zero = self._make_sure_a_list(initially_zero)

        super().__init__(sparsifier, last_epoch, verbose)

    @staticmethod
    def sparsity_compute_fn(s_0, s_f, t, t_0, dt, n, initially_zero=False):
        r"""Computes the current level of sparsity.

        Based on https://arxiv.org/pdf/1710.01878.pdf

        Args:
            s_0: Initial level of sparsity, :math:`s_0`
            s_f: Target level of sparsity, :math:`s_f`
            t: Current step, :math:`t`
            t_0: Initial step, :math:`t_0`
            dt: Pruning frequency, :math:`\Delta t`
            n: Pruning steps, :math:`n`
            initially_zero: Sets the level of sparsity to 0 before t_0.
                If False, sets to s_0

        Returns:
            The sparsity level :math:`s_t` at the current step :math:`t`,
            clamped to the range spanned by ``s_0`` and ``s_f``.

        Raises:
            ZeroDivisionError: If ``dt * n`` is zero (for plain Python
                numbers).
        """
        if initially_zero and t < t_0:
            return 0
        s_t = s_f + (s_0 - s_f) * (1.0 - (t - t_0) / (dt * n)) ** 3
        # Outside [t_0, t_0 + dt * n] the cubic overshoots its endpoints, so
        # clamp to them. Order the bounds first: passing (s_0, s_f) directly
        # would pin the result to s_0 whenever s_0 > s_f.
        lower, upper = (s_0, s_f) if s_0 <= s_f else (s_f, s_0)
        return _clamp(s_t, lower, upper)

    def get_sl(self):
        """Computes the sparsity level of every group at ``self.last_epoch``.

        Emits a warning when ``self._get_sl_called_within_step`` is falsy,
        pointing the caller to ``get_last_sl()`` instead.

        Returns:
            A list with one sparsity level per zipped entry of the
            per-group settings and ``self.base_sl``.
        """
        if not self._get_sl_called_within_step:
            warnings.warn(
                "To get the last sparsity level computed by the scheduler, "
                "please use `get_last_sl()`."
            )
        per_group_settings = zip(
            self.init_sl,
            self.base_sl,
            self.init_t,
            self.delta_t,
            self.total_t,
            self.initially_zero,
        )
        return [
            self.sparsity_compute_fn(
                s_0=initial_sparsity,
                s_f=final_sparsity,
                t=self.last_epoch,
                t_0=initial_epoch,
                dt=delta_epoch,
                n=interval_epochs,
                initially_zero=initially_zero,
            )
            for (
                initial_sparsity,
                final_sparsity,
                initial_epoch,
                delta_epoch,
                interval_epochs,
                initially_zero,
            ) in per_group_settings
        ]