Muinez commited on
Commit
2df3e13
·
verified ·
1 Parent(s): 30818b6

Upload snooc.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. snooc.py +85 -0
snooc.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.optim import Optimizer
3
+
4
+ class SnooC(Optimizer):
5
+ """
6
+ Fixed SnooC Optimizer
7
+ """
8
+ @torch.no_grad()
9
+ def __init__(self, optimizer, lr: float = 0.67, momentum: float = 0.67, k: int = 20) -> None:
10
+ self.optimizer = optimizer
11
+ self.lr = lr
12
+ self.momentum = momentum
13
+ self.k = k
14
+ self.current_step = 0
15
+ self.model_params = None
16
+ self.outer_buf = None
17
+ self.outer_optimizer = None
18
+
19
+ if self.optimizer.param_groups:
20
+ self.param_groups = self.optimizer.param_groups
21
+
22
+ @torch.no_grad()
23
+ def _initialize_outer_optimizer(self):
24
+ # Исправленная логика сбора параметров
25
+ params = []
26
+ for pg in self.optimizer.param_groups:
27
+ for param in pg['params']:
28
+ # Собираем только тензоры, требующие градиента (обычно это то, что нужно оптимизировать)
29
+ if isinstance(param, torch.Tensor) and param.requires_grad:
30
+ params.append(param)
31
+
32
+ if not params:
33
+ return
34
+
35
+ self.model_params = list(params)
36
+ self.outer_buf = [p.clone() for p in self.model_params]
37
+
38
+ # Инициализируем внешний оптимизатор только если есть параметры
39
+ self.outer_optimizer = torch.optim.SGD(
40
+ self.model_params,
41
+ lr=self.lr,
42
+ momentum=self.momentum,
43
+ nesterov=True,
44
+ # fused=True может вызывать ошибки на некоторых версиях torch/hw,
45
+ # можно поставить False, если будет падать дальше, но пока оставим True
46
+ fused=True,
47
+ )
48
+ self.param_groups = self.optimizer.param_groups
49
+
50
+ @torch.no_grad()
51
+ def step(self, closure=None):
52
+ if self.outer_optimizer is None or self.current_step == 0:
53
+ if self.optimizer.param_groups:
54
+ self._initialize_outer_optimizer()
55
+
56
+ # Если после попытки инициализации параметры все еще None,
57
+ # значит оптимизировать нечего, просто делаем шаг базового оптимизатора
58
+ if self.model_params is None:
59
+ return self.optimizer.step(closure)
60
+
61
+ loss = self.optimizer.step(closure)
62
+
63
+ # Добавляем проверку на None здесь на всякий случай
64
+ if self.model_params is not None and self.current_step % self.k == 0:
65
+ for p_new, p_old in zip(self.model_params, self.outer_buf):
66
+ if p_new.grad is None: continue # Защита от отсутствующих градиентов
67
+ p_new.grad = p_old.data - p_new.data
68
+ p_new.copy_(p_old, non_blocking=True)
69
+
70
+ self.outer_optimizer.step()
71
+
72
+ for p_new, p_old in zip(self.model_params, self.outer_buf):
73
+ p_old.copy_(p_new, non_blocking=True)
74
+
75
+ self.current_step += 1
76
+ return loss
77
+
78
+ def zero_grad(self, set_to_none: bool = False):
79
+ self.optimizer.zero_grad(set_to_none=set_to_none)
80
+
81
+ def state_dict(self):
82
+ return self.optimizer.state_dict()
83
+
84
+ def load_state_dict(self, state_dict):
85
+ self.optimizer.load_state_dict(state_dict)