File size: 12,687 Bytes
dfefe0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index e0a49ee5795e..216d5cd42960 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1237,6 +1237,10 @@ def get_optimizer_cls_and_kwargs(
             OptimizerNames.ADAMW_8BIT,
             OptimizerNames.PAGED_ADAMW,
             OptimizerNames.PAGED_ADAMW_8BIT,
+            OptimizerNames.ADEMAMIX,
+            OptimizerNames.ADEMAMIX_8BIT,
+            OptimizerNames.PAGED_ADEMAMIX,
+            OptimizerNames.PAGED_ADEMAMIX_8BIT,
             OptimizerNames.LION,
             OptimizerNames.LION_8BIT,
             OptimizerNames.PAGED_LION,
@@ -1266,6 +1270,33 @@ def get_optimizer_cls_and_kwargs(
                     # Above we pass all `adam_kwargs` to the optimizer, here
                     # we only pass `optim_args` which can be passed by the user.
                     additional_optim_kwargs = optim_args
+                elif "ademamix" in args.optim:
+                    if is_bitsandbytes_available() and version.parse(
+                        importlib.metadata.version("bitsandbytes")
+                    ) < version.parse("0.44.0"):
+                        raise ValueError(
+                            "The AdEMAMix optimizer is not supported by your current version of `bitsandbytes`. "
+                            "Please install `bitsandbytes` >= 0.44.0."
+                        )
+
+                    from bitsandbytes.optim import AdEMAMix
+
+                    optimizer_cls = AdEMAMix
+                    additional_optim_kwargs = {
+                        "betas": (
+                            float(optim_args.get("beta1", args.adam_beta1)),
+                            float(optim_args.get("beta2", args.adam_beta2)),
+                            float(optim_args.get("beta3", 0.9999)),
+                        ),
+                        "alpha": float(optim_args.get("alpha", 5.0)),
+                        "eps": float(optim_args.get("eps", args.adam_epsilon)),
+                    }
+
+                    if "t_alpha" in optim_args:
+                        additional_optim_kwargs["t_alpha"] = int(optim_args["t_alpha"])
+
+                    if "t_beta3" in optim_args:
+                        additional_optim_kwargs["t_beta3"] = int(optim_args["t_beta3"])
 
                 bnb_kwargs = {"optim_bits": optim_bits}
                 if "rmsprop" not in args.optim:
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 02413c285832..596917928350 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -155,14 +155,18 @@ class OptimizerNames(ExplicitEnum):
     ADAFACTOR = "adafactor"
     ADAMW_ANYPRECISION = "adamw_anyprecision"
     ADAMW_TORCH_4BIT = "adamw_torch_4bit"
+    ADEMAMIX = "ademamix"
     SGD = "sgd"
     ADAGRAD = "adagrad"
     ADAMW_BNB = "adamw_bnb_8bit"
     ADAMW_8BIT = "adamw_8bit"  # just an alias for adamw_bnb_8bit
+    ADEMAMIX_8BIT = "ademamix_8bit"
     LION_8BIT = "lion_8bit"
     LION = "lion_32bit"
     PAGED_ADAMW = "paged_adamw_32bit"
     PAGED_ADAMW_8BIT = "paged_adamw_8bit"
+    PAGED_ADEMAMIX = "paged_ademamix_32bit"
+    PAGED_ADEMAMIX_8BIT = "paged_ademamix_8bit"
     PAGED_LION = "paged_lion_32bit"
     PAGED_LION_8BIT = "paged_lion_8bit"
     RMSPROP = "rmsprop"
@@ -618,7 +622,7 @@ class TrainingArguments:
             "adafactor". See `OptimizerNames` in [training_args.py](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py)
             for a full list of optimizers.
         optim_args (`str`, *optional*):
-            Optional arguments that are supplied to AnyPrecisionAdamW.
+            Optional arguments that are supplied to optimizers such as AnyPrecisionAdamW, AdEMAMix, and GaLore.
         group_by_length (`bool`, *optional*, defaults to `False`):
             Whether or not to group together samples of roughly the same length in the training dataset (to minimize
             padding applied and be more efficient). Only useful if applying dynamic padding.
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 14014e4a0947..0035ff7de8ba 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -15,6 +15,7 @@
 
 import dataclasses
 import gc
+import importlib
 import json
 import math
 import os
@@ -32,6 +33,7 @@
 
 import numpy as np
 from huggingface_hub import HfFolder, ModelCard, create_branch, delete_repo, list_repo_commits, list_repo_files
+from packaging import version
 from parameterized import parameterized
 from requests.exceptions import HTTPError
 
@@ -1091,6 +1093,40 @@ def test_rmsprop_bnb(self):
             # Check that it trains without errors
             trainer.train()
 
+    @require_bitsandbytes
+    def test_ademamix_bnb(self):
+        config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
+        tiny_gpt2 = GPT2LMHeadModel(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="ademamix"
+            )
+            trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
+
+            # Check that it trains without errors
+            trainer.train()
+
+    @require_bitsandbytes
+    def test_ademamix_bnb_8bit(self):
+        config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
+        tiny_gpt2 = GPT2LMHeadModel(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="ademamix_8bit"
+            )
+            trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
+
+            # Check that it trains without errors
+            trainer.train()
+
     @require_bitsandbytes
     def test_rmsprop_bnb_8bit(self):
         config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
@@ -4187,6 +4223,13 @@ def hp_name(trial):
         "lr": TrainingArguments.learning_rate,
     }
 
+    default_ademamix_kwargs = {
+        "betas": (TrainingArguments.adam_beta1, TrainingArguments.adam_beta2, 0.9999),
+        "alpha": 5.0,
+        "eps": TrainingArguments.adam_epsilon,
+        "lr": TrainingArguments.learning_rate,
+    }
+
     default_anyprecision_kwargs = {
         "use_kahan_summation": False,
         "momentum_dtype": torch.float32,
@@ -4291,6 +4334,36 @@ def hp_name(trial):
             )
         )
 
+        if version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.44.0"):
+            optim_test_params.append(
+                (
+                    TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None"),
+                    bnb.optim.AdEMAMix,
+                    default_ademamix_kwargs,
+                )
+            )
+            optim_test_params.append(
+                (
+                    TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None"),
+                    bnb.optim.AdEMAMix,
+                    default_ademamix_kwargs,
+                )
+            )
+            optim_test_params.append(
+                (
+                    TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None"),
+                    bnb.optim.AdEMAMix,
+                    default_ademamix_kwargs,
+                )
+            )
+            optim_test_params.append(
+                (
+                    TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None"),
+                    bnb.optim.AdEMAMix,
+                    default_ademamix_kwargs,
+                )
+            )
+
     if is_torchdistx_available():
         import torchdistx
 
@@ -4420,6 +4493,62 @@ def test_bnb_paged_adam8bit(self):
                 default_adam_kwargs,
             )
 
+    def test_bnb_ademamix(self):
+        mock = Mock()
+        modules = {
+            "bitsandbytes": mock,
+            "bitsandbytes.optim": mock.optim,
+            "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None"),
+                mock.optim.AdEMAMix,
+                default_ademamix_kwargs,
+            )
+
+    def test_bnb_ademamix8bit(self):
+        mock = Mock()
+        modules = {
+            "bitsandbytes": mock,
+            "bitsandbytes.optim": mock.optim,
+            "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None"),
+                mock.optim.AdEMAMix,
+                default_ademamix_kwargs,
+            )
+
+    def test_bnb_paged_ademamix(self):
+        mock = Mock()
+        modules = {
+            "bitsandbytes": mock,
+            "bitsandbytes.optim": mock.optim,
+            "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None"),
+                mock.optim.AdEMAMix,
+                default_ademamix_kwargs,
+            )
+
+    def test_bnb_paged_ademamix8bit(self):
+        mock = Mock()
+        modules = {
+            "bitsandbytes": mock,
+            "bitsandbytes.optim": mock.optim,
+            "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None"),
+                mock.optim.AdEMAMix,
+                default_ademamix_kwargs,
+            )
+
     def test_bnb_lion(self):
         mock = Mock()
         modules = {
@@ -4503,6 +4632,42 @@ def test_bnb_paged_adam8bit_no_bnb(self):
             with self.assertRaises(ValueError):
                 Trainer.get_optimizer_cls_and_kwargs(args)
 
+    def test_bnb_ademamix_no_bnb(self):
+        args = TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None")
+
+        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
+        # bnb will fail even if `bitsandbytes` is installed.
+        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
+            with self.assertRaises(ValueError):
+                Trainer.get_optimizer_cls_and_kwargs(args)
+
+    def test_bnb_ademamix8bit_no_bnb(self):
+        args = TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None")
+
+        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
+        # bnb will fail even if `bitsandbytes` is installed.
+        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
+            with self.assertRaises(ValueError):
+                Trainer.get_optimizer_cls_and_kwargs(args)
+
+    def test_bnb_paged_ademamix_no_bnb(self):
+        args = TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None")
+
+        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
+        # bnb will fail even if `bitsandbytes` is installed.
+        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
+            with self.assertRaises(ValueError):
+                Trainer.get_optimizer_cls_and_kwargs(args)
+
+    def test_bnb_paged_ademamix8bit_no_bnb(self):
+        args = TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None")
+
+        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
+        # bnb will fail even if `bitsandbytes` is installed.
+        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
+            with self.assertRaises(ValueError):
+                Trainer.get_optimizer_cls_and_kwargs(args)
+
     def test_bnb_paged_lion_no_bnb(self):
         args = TrainingArguments(optim=OptimizerNames.PAGED_LION, output_dir="None")