use inpalce op in update_g
Browse files- torch-ext/optimizer/muon.py +9 -12
torch-ext/optimizer/muon.py
CHANGED
|
@@ -650,15 +650,12 @@ class Muon(torch.optim.Optimizer):
|
|
| 650 |
def _update_g(self, p, g, group, momentum):
|
| 651 |
# calc update
|
| 652 |
state = self.state[p]
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
buf = state["momentum_buffer"]
|
| 656 |
-
buf.mul_(momentum).add_(g)
|
| 657 |
if group["nesterov"]:
|
| 658 |
-
g
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
return g
|
| 662 |
|
| 663 |
@staticmethod
|
| 664 |
def _update_p(p, u, lr, adjusted_lr, weight_decay):
|
|
@@ -704,10 +701,10 @@ class Muon(torch.optim.Optimizer):
|
|
| 704 |
new_scale = math.sqrt(threshold / v_ele)
|
| 705 |
if new_scale < scales_full[head_idx]:
|
| 706 |
scales_full[head_idx] = new_scale
|
| 707 |
-
logger.info(
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
)
|
| 711 |
scaling += 1
|
| 712 |
|
| 713 |
return scales_full if scaling > 0 else None
|
|
|
|
| 650 |
def _update_g(self, p, g, group, momentum):
|
| 651 |
# calc update
|
| 652 |
state = self.state[p]
|
| 653 |
+
buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
|
| 654 |
+
torch.add(g, buf, alpha=momentum, out=buf)
|
|
|
|
|
|
|
| 655 |
if group["nesterov"]:
|
| 656 |
+
g.add_(buf, alpha=momentum)
|
| 657 |
+
return g
|
| 658 |
+
return buf
|
|
|
|
| 659 |
|
| 660 |
@staticmethod
|
| 661 |
def _update_p(p, u, lr, adjusted_lr, weight_decay):
|
|
|
|
| 701 |
new_scale = math.sqrt(threshold / v_ele)
|
| 702 |
if new_scale < scales_full[head_idx]:
|
| 703 |
scales_full[head_idx] = new_scale
|
| 704 |
+
#logger.info(
|
| 705 |
+
# f"[{kind}] Head {head_idx} exceeded threshold "
|
| 706 |
+
# f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
|
| 707 |
+
#)
|
| 708 |
scaling += 1
|
| 709 |
|
| 710 |
return scales_full if scaling > 0 else None
|