fix init
#7
by
leejunhyeok - opened
torch-ext/optimizer/muon.py
CHANGED
|
@@ -210,7 +210,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 210 |
|
| 211 |
def __init__(
|
| 212 |
self,
|
| 213 |
-
|
| 214 |
is_muon_func=default_is_muon,
|
| 215 |
lr=1e-3,
|
| 216 |
momentum=0.95,
|
|
@@ -233,7 +233,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 233 |
none_grad=none_grad,
|
| 234 |
)
|
| 235 |
|
| 236 |
-
super().__init__(
|
| 237 |
self.is_muon_func = is_muon_func
|
| 238 |
self.model = model
|
| 239 |
|
|
|
|
| 210 |
|
| 211 |
def __init__(
|
| 212 |
self,
|
| 213 |
+
params,
|
| 214 |
is_muon_func=default_is_muon,
|
| 215 |
lr=1e-3,
|
| 216 |
momentum=0.95,
|
|
|
|
| 233 |
none_grad=none_grad,
|
| 234 |
)
|
| 235 |
|
| 236 |
+
super().__init__(params, defaults)
|
| 237 |
self.is_muon_func = is_muon_func
|
| 238 |
self.model = model
|
| 239 |
|