Kernels
ca1207 committed on
Commit
6e9baad
·
1 Parent(s): 2a8631f

use in-place op in update_g

Browse files
Files changed (1) hide show
  1. torch-ext/optimizer/muon.py +9 -12
torch-ext/optimizer/muon.py CHANGED
@@ -650,15 +650,12 @@ class Muon(torch.optim.Optimizer):
650
  def _update_g(self, p, g, group, momentum):
651
  # calc update
652
  state = self.state[p]
653
- if "momentum_buffer" not in state:
654
- state["momentum_buffer"] = torch.zeros_like(g)
655
- buf = state["momentum_buffer"]
656
- buf.mul_(momentum).add_(g)
657
  if group["nesterov"]:
658
- g = g.add(buf, alpha=momentum)
659
- else:
660
- g = buf
661
- return g
662
 
663
  @staticmethod
664
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
@@ -704,10 +701,10 @@ class Muon(torch.optim.Optimizer):
704
  new_scale = math.sqrt(threshold / v_ele)
705
  if new_scale < scales_full[head_idx]:
706
  scales_full[head_idx] = new_scale
707
- logger.info(
708
- f"[{kind}] Head {head_idx} exceeded threshold "
709
- f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
710
- )
711
  scaling += 1
712
 
713
  return scales_full if scaling > 0 else None
 
650
  def _update_g(self, p, g, group, momentum):
651
  # calc update
652
  state = self.state[p]
653
+ buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
654
+ torch.add(g, buf, alpha=momentum, out=buf)
 
 
655
  if group["nesterov"]:
656
+ g.add_(buf, alpha=momentum)
657
+ return g
658
+ return buf
 
659
 
660
  @staticmethod
661
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
 
701
  new_scale = math.sqrt(threshold / v_ele)
702
  if new_scale < scales_full[head_idx]:
703
  scales_full[head_idx] = new_scale
704
+ #logger.info(
705
+ # f"[{kind}] Head {head_idx} exceeded threshold "
706
+ # f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
707
+ #)
708
  scaling += 1
709
 
710
  return scales_full if scaling > 0 else None