Kernels
Files changed (1) hide show
  1. torch-ext/optimizer/muon.py +2 -2
torch-ext/optimizer/muon.py CHANGED
@@ -210,7 +210,7 @@ class Muon(torch.optim.Optimizer):
210
 
211
  def __init__(
212
  self,
213
- model,
214
  is_muon_func=default_is_muon,
215
  lr=1e-3,
216
  momentum=0.95,
@@ -233,7 +233,7 @@ class Muon(torch.optim.Optimizer):
233
  none_grad=none_grad,
234
  )
235
 
236
- super().__init__(model.parameters(), defaults)
237
  self.is_muon_func = is_muon_func
238
  self.model = model
239
 
 
210
 
211
  def __init__(
212
  self,
213
+ params,
214
  is_muon_func=default_is_muon,
215
  lr=1e-3,
216
  momentum=0.95,
 
233
  none_grad=none_grad,
234
  )
235
 
236
+ super().__init__(params, defaults)
237
  self.is_muon_func = is_muon_func
238
  self.model = model
239