import torch
from torch.optim import Optimizer

class SnooC(Optimizer):
    """
    @DominikKallusky, @vishal9-team, @vinaysrao

    Sparse Nesterov Outer Optimizer (Snoo) is a momentum-based wrapper to any optimizer that can
    improve the stability and smoothness of the optimization process and thus the quality
    of large language models (LLM) and other models. Snoo implicitly adds temporal regularization
    to the parameters, thus smoothing the training trajectory and instilling a bias towards flatter
    minima and lower parameter norms. Snoo is computationally efficient, incurring minimal overhead
    in compute and moderate memory usage.
    """

    @torch.no_grad()
    def __init__(self, optimizer, lr: float = 0.67, momentum: float = 0.67, k: int = 20) -> None:
        self.optimizer = optimizer
        self.lr = lr
        self.momentum = momentum
        self.k = k
        self.current_step = 0
        self.model_params = None
        self.outer_buf = None
        self.outer_optimizer = None

        # Expose the inner optimizer's param_groups; this is a live reference,
        # so it stays in sync if parameters are added later.
        self.param_groups = self.optimizer.param_groups
    
    @torch.no_grad()
    def _initialize_outer_optimizer(self):
        # Gather every tensor parameter across all inner param groups.
        params = []
        for pg in self.optimizer.param_groups:
            params.extend(p for p in pg['params'] if isinstance(p, torch.Tensor))

        if not params:
            return

        self.model_params = params
        # Anchor copies of the parameters; each outer step measures progress
        # against these and then refreshes them.
        self.outer_buf = [p.clone() for p in self.model_params]
        self.outer_optimizer = torch.optim.SGD(
            self.model_params,
            lr=self.lr,
            momentum=self.momentum,
            nesterov=True,
            # Fused SGD kernels are only guaranteed for CUDA tensors; use the
            # default implementation elsewhere.
            fused=all(p.is_cuda for p in self.model_params),
        )
        self.param_groups = self.optimizer.param_groups

    @torch.no_grad()
    def step(self, closure=None):
        if self.outer_optimizer is None:
            # If the optimizer has been populated with parameters, initialize.
            if self.optimizer.param_groups:
                self._initialize_outer_optimizer()
            else:
                # With no parameters there is no outer state to maintain, so
                # simply delegate to the inner optimizer.
                return self.optimizer.step(closure)

        loss = self.optimizer.step(closure)
        self.current_step += 1
        # Every k inner steps, apply the sparse outer Nesterov update.
        if self.outer_optimizer is not None and self.current_step % self.k == 0:
            # Treat the displacement from the anchor as a pseudo-gradient and
            # step the outer optimizer from the anchor point.
            for p_new, p_old in zip(self.model_params, self.outer_buf):
                p_new.grad = p_old - p_new
                p_new.copy_(p_old, non_blocking=True)

            self.outer_optimizer.step()

            # The outer result becomes the new anchor.
            for p_new, p_old in zip(self.model_params, self.outer_buf):
                p_old.copy_(p_new, non_blocking=True)
        return loss
    
    def zero_grad(self, set_to_none: bool = True):
        # Match the default of torch.optim.Optimizer.zero_grad (True since
        # PyTorch 2.0) while delegating to the inner optimizer.
        self.optimizer.zero_grad(set_to_none=set_to_none)

    def state_dict(self):
        # Note: only the inner optimizer is checkpointed; the outer momentum
        # buffer, anchor copies, and step counter are rebuilt on resume.
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)
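
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original interface): wraps a
# standard inner optimizer with SnooC on a toy regression problem. The model,
# data, and hyperparameters below are assumptions chosen for the demo, not
# values recommended by the authors.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    model = torch.nn.Linear(8, 1)
    inner = torch.optim.AdamW(model.parameters(), lr=1e-3)
    optimizer = SnooC(inner, lr=0.67, momentum=0.67, k=20)

    x = torch.randn(64, 8)
    y = torch.randn(64, 1)
    for _ in range(60):
        loss = torch.nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # the outer Nesterov step fires every k inner steps
    print(f"final loss: {loss.item():.4f}")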