diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index e0a49ee5795e..216d5cd42960 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1237,6 +1237,10 @@ def get_optimizer_cls_and_kwargs( OptimizerNames.ADAMW_8BIT, OptimizerNames.PAGED_ADAMW, OptimizerNames.PAGED_ADAMW_8BIT, + OptimizerNames.ADEMAMIX, + OptimizerNames.ADEMAMIX_8BIT, + OptimizerNames.PAGED_ADEMAMIX, + OptimizerNames.PAGED_ADEMAMIX_8BIT, OptimizerNames.LION, OptimizerNames.LION_8BIT, OptimizerNames.PAGED_LION, @@ -1266,6 +1270,33 @@ def get_optimizer_cls_and_kwargs( # Above we pass all `adam_kwargs` to the optimizer, here # we only pass `optim_args` which can be passed by the user. additional_optim_kwargs = optim_args + elif "ademamix" in args.optim: + if is_bitsandbytes_available() and version.parse( + importlib.metadata.version("bitsandbytes") + ) < version.parse("0.44.0"): + raise ValueError( + "The AdEMAMix optimizer is not supported by your current version of `bitsandbytes`. " + "Please install `bitsandbytes` >= 0.44.0." + ) + + from bitsandbytes.optim import AdEMAMix + + optimizer_cls = AdEMAMix + additional_optim_kwargs = { + "betas": ( + float(optim_args.get("beta1", args.adam_beta1)), + float(optim_args.get("beta2", args.adam_beta2)), + float(optim_args.get("beta3", 0.9999)), + ), + "alpha": float(optim_args.get("alpha", 5.0)), + "eps": float(optim_args.get("eps", args.adam_epsilon)), + } + + if "t_alpha" in optim_args: + additional_optim_kwargs["t_alpha"] = int(optim_args["t_alpha"]) + + if "t_beta3" in optim_args: + additional_optim_kwargs["t_beta3"] = int(optim_args["t_beta3"]) bnb_kwargs = {"optim_bits": optim_bits} if "rmsprop" not in args.optim: diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 02413c285832..596917928350 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -155,14 +155,18 @@ class OptimizerNames(ExplicitEnum): ADAFACTOR = "adafactor" ADAMW_ANYPRECISION = "adamw_anyprecision" ADAMW_TORCH_4BIT = "adamw_torch_4bit" + ADEMAMIX = "ademamix" SGD = "sgd" ADAGRAD = "adagrad" ADAMW_BNB = "adamw_bnb_8bit" ADAMW_8BIT = "adamw_8bit" # just an alias for adamw_bnb_8bit + ADEMAMIX_8BIT = "ademamix_8bit" LION_8BIT = "lion_8bit" LION = "lion_32bit" PAGED_ADAMW = "paged_adamw_32bit" PAGED_ADAMW_8BIT = "paged_adamw_8bit" + PAGED_ADEMAMIX = "paged_ademamix_32bit" + PAGED_ADEMAMIX_8BIT = "paged_ademamix_8bit" PAGED_LION = "paged_lion_32bit" PAGED_LION_8BIT = "paged_lion_8bit" RMSPROP = "rmsprop" @@ -618,7 +622,7 @@ class TrainingArguments: "adafactor". See `OptimizerNames` in [training_args.py](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py) for a full list of optimizers. optim_args (`str`, *optional*): - Optional arguments that are supplied to AnyPrecisionAdamW. + Optional arguments that are supplied to optimizers such as AnyPrecisionAdamW, AdEMAMix, and GaLore. group_by_length (`bool`, *optional*, defaults to `False`): Whether or not to group together samples of roughly the same length in the training dataset (to minimize padding applied and be more efficient). Only useful if applying dynamic padding. diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 14014e4a0947..0035ff7de8ba 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -15,6 +15,7 @@ import dataclasses import gc +import importlib import json import math import os @@ -32,6 +33,7 @@ import numpy as np from huggingface_hub import HfFolder, ModelCard, create_branch, delete_repo, list_repo_commits, list_repo_files +from packaging import version from parameterized import parameterized from requests.exceptions import HTTPError @@ -1091,6 +1093,40 @@ def test_rmsprop_bnb(self): # Check that it trains without errors trainer.train() + @require_bitsandbytes + def test_ademamix_bnb(self): + config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4) + tiny_gpt2 = GPT2LMHeadModel(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="ademamix" + ) + trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset) + + # Check that it trains without errors + trainer.train() + + @require_bitsandbytes + def test_ademamix_bnb_8bit(self): + config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4) + tiny_gpt2 = GPT2LMHeadModel(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="ademamix_8bit" + ) + trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset) + + # Check that it trains without errors + trainer.train() + @require_bitsandbytes def test_rmsprop_bnb_8bit(self): config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4) @@ -4187,6 +4223,13 @@ def hp_name(trial): "lr": TrainingArguments.learning_rate, } + default_ademamix_kwargs = { + "betas": (TrainingArguments.adam_beta1, TrainingArguments.adam_beta2, 0.9999), + "alpha": 5.0, + "eps": TrainingArguments.adam_epsilon, + "lr": TrainingArguments.learning_rate, + } + default_anyprecision_kwargs = { "use_kahan_summation": False, "momentum_dtype": torch.float32, @@ -4291,6 +4334,36 @@ def hp_name(trial): ) ) + if version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.44.0"): + optim_test_params.append( + ( + TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None"), + bnb.optim.AdEMAMix, + default_ademamix_kwargs, + ) + ) + optim_test_params.append( + ( + TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None"), + bnb.optim.AdEMAMix, + default_ademamix_kwargs, + ) + ) + optim_test_params.append( + ( + TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None"), + bnb.optim.AdEMAMix, + default_ademamix_kwargs, + ) + ) + optim_test_params.append( + ( + TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None"), + bnb.optim.AdEMAMix, + default_ademamix_kwargs, + ) + ) + if is_torchdistx_available(): import torchdistx @@ -4420,6 +4493,62 @@ def test_bnb_paged_adam8bit(self): default_adam_kwargs, ) + def test_bnb_ademamix(self): + mock = Mock() + modules = { + "bitsandbytes": mock, + "bitsandbytes.optim": mock.optim, + "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix, + } + with patch.dict("sys.modules", modules): + self.check_optim_and_kwargs( + TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None"), + mock.optim.AdEMAMix, + default_ademamix_kwargs, + ) + + def test_bnb_ademamix8bit(self): + mock = Mock() + modules = { + "bitsandbytes": mock, + "bitsandbytes.optim": mock.optim, + "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix, + } + with patch.dict("sys.modules", modules): + self.check_optim_and_kwargs( + TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None"), + mock.optim.AdEMAMix, + default_ademamix_kwargs, + ) + + def test_bnb_paged_ademamix(self): + mock = Mock() + modules = { + "bitsandbytes": mock, + "bitsandbytes.optim": mock.optim, + "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix, + } + with patch.dict("sys.modules", modules): + self.check_optim_and_kwargs( + TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None"), + mock.optim.AdEMAMix, + default_ademamix_kwargs, + ) + + def test_bnb_paged_ademamix8bit(self): + mock = Mock() + modules = { + "bitsandbytes": mock, + "bitsandbytes.optim": mock.optim, + "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix, + } + with patch.dict("sys.modules", modules): + self.check_optim_and_kwargs( + TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None"), + mock.optim.AdEMAMix, + default_ademamix_kwargs, + ) + def test_bnb_lion(self): mock = Mock() modules = { @@ -4503,6 +4632,42 @@ def test_bnb_paged_adam8bit_no_bnb(self): with self.assertRaises(ValueError): Trainer.get_optimizer_cls_and_kwargs(args) + def test_bnb_ademamix_no_bnb(self): + args = TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None") + + # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing + # bnb will fail even if `bitsandbytes` is installed. + with patch.dict("sys.modules", {"bitsandbytes.optim": None}): + with self.assertRaises(ValueError): + Trainer.get_optimizer_cls_and_kwargs(args) + + def test_bnb_ademamix8bit_no_bnb(self): + args = TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None") + + # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing + # bnb will fail even if `bitsandbytes` is installed. + with patch.dict("sys.modules", {"bitsandbytes.optim": None}): + with self.assertRaises(ValueError): + Trainer.get_optimizer_cls_and_kwargs(args) + + def test_bnb_paged_ademamix_no_bnb(self): + args = TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None") + + # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing + # bnb will fail even if `bitsandbytes` is installed. + with patch.dict("sys.modules", {"bitsandbytes.optim": None}): + with self.assertRaises(ValueError): + Trainer.get_optimizer_cls_and_kwargs(args) + + def test_bnb_paged_ademamix8bit_no_bnb(self): + args = TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None") + + # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing + # bnb will fail even if `bitsandbytes` is installed. + with patch.dict("sys.modules", {"bitsandbytes.optim": None}): + with self.assertRaises(ValueError): + Trainer.get_optimizer_cls_and_kwargs(args) + def test_bnb_paged_lion_no_bnb(self): args = TrainingArguments(optim=OptimizerNames.PAGED_LION, output_dir="None")