File size: 3,628 Bytes
2df3e13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
from torch.optim import Optimizer

class SnooC(Optimizer):
    """

    Fixed SnooC Optimizer

    """
    @torch.no_grad()
    def __init__(self, optimizer, lr: float = 0.67, momentum: float = 0.67, k: int = 20) -> None:
        self.optimizer = optimizer
        self.lr = lr
        self.momentum = momentum
        self.k = k
        self.current_step = 0
        self.model_params = None
        self.outer_buf = None
        self.outer_optimizer = None

        if self.optimizer.param_groups:
            self.param_groups = self.optimizer.param_groups
    
    @torch.no_grad()
    def _initialize_outer_optimizer(self):
        # Исправленная логика сбора параметров
        params = []
        for pg in self.optimizer.param_groups:
            for param in pg['params']:
                # Собираем только тензоры, требующие градиента (обычно это то, что нужно оптимизировать)
                if isinstance(param, torch.Tensor) and param.requires_grad:
                    params.append(param)
        
        if not params:
            return

        self.model_params = list(params)
        self.outer_buf = [p.clone() for p in self.model_params]
        
        # Инициализируем внешний оптимизатор только если есть параметры
        self.outer_optimizer = torch.optim.SGD(
            self.model_params,
            lr=self.lr,
            momentum=self.momentum,
            nesterov=True,
            # fused=True может вызывать ошибки на некоторых версиях torch/hw, 
            # можно поставить False, если будет падать дальше, но пока оставим True
            fused=True, 
        )
        self.param_groups = self.optimizer.param_groups

    @torch.no_grad()
    def step(self, closure=None):
        if self.outer_optimizer is None or self.current_step == 0:
            if self.optimizer.param_groups:
                self._initialize_outer_optimizer()
            
            # Если после попытки инициализации параметры все еще None, 
            # значит оптимизировать нечего, просто делаем шаг базового оптимизатора
            if self.model_params is None:
                return self.optimizer.step(closure)

        loss = self.optimizer.step(closure)
        
        # Добавляем проверку на None здесь на всякий случай
        if self.model_params is not None and self.current_step % self.k == 0:
            for p_new, p_old in zip(self.model_params, self.outer_buf):
                if p_new.grad is None: continue # Защита от отсутствующих градиентов
                p_new.grad = p_old.data - p_new.data
                p_new.copy_(p_old, non_blocking=True)

            self.outer_optimizer.step()

            for p_new, p_old in zip(self.model_params, self.outer_buf):
                p_old.copy_(p_new, non_blocking=True)
                
        self.current_step += 1
        return loss
    
    def zero_grad(self, set_to_none: bool = False):
        self.optimizer.zero_grad(set_to_none=set_to_none)

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)