| |
| |
| |
| |
| @@ -1237,6 +1237,10 @@ def get_optimizer_cls_and_kwargs( |
| OptimizerNames.ADAMW_8BIT, |
| OptimizerNames.PAGED_ADAMW, |
| OptimizerNames.PAGED_ADAMW_8BIT, |
| + OptimizerNames.ADEMAMIX, |
| + OptimizerNames.ADEMAMIX_8BIT, |
| + OptimizerNames.PAGED_ADEMAMIX, |
| + OptimizerNames.PAGED_ADEMAMIX_8BIT, |
| OptimizerNames.LION, |
| OptimizerNames.LION_8BIT, |
| OptimizerNames.PAGED_LION, |
| @@ -1266,6 +1270,33 @@ def get_optimizer_cls_and_kwargs( |
| # Above we pass all `adam_kwargs` to the optimizer, here |
| # we only pass `optim_args` which can be passed by the user. |
| additional_optim_kwargs = optim_args |
| + elif "ademamix" in args.optim: |
| + if is_bitsandbytes_available() and version.parse( |
| + importlib.metadata.version("bitsandbytes") |
| + ) < version.parse("0.44.0"): |
| + raise ValueError( |
| + "The AdEMAMix optimizer is not supported by your current version of `bitsandbytes`. " |
| + "Please install `bitsandbytes` >= 0.44.0." |
| + ) |
| + |
| + from bitsandbytes.optim import AdEMAMix |
| + |
| + optimizer_cls = AdEMAMix |
| + additional_optim_kwargs = { |
| + "betas": ( |
| + float(optim_args.get("beta1", args.adam_beta1)), |
| + float(optim_args.get("beta2", args.adam_beta2)), |
| + float(optim_args.get("beta3", 0.9999)), |
| + ), |
| + "alpha": float(optim_args.get("alpha", 5.0)), |
| + "eps": float(optim_args.get("eps", args.adam_epsilon)), |
| + } |
| + |
| + if "t_alpha" in optim_args: |
| + additional_optim_kwargs["t_alpha"] = int(optim_args["t_alpha"]) |
| + |
| + if "t_beta3" in optim_args: |
| + additional_optim_kwargs["t_beta3"] = int(optim_args["t_beta3"]) |
| |
| bnb_kwargs = {"optim_bits": optim_bits} |
| if "rmsprop" not in args.optim: |
| |
| |
| |
| |
| @@ -155,14 +155,18 @@ class OptimizerNames(ExplicitEnum): |
| ADAFACTOR = "adafactor" |
| ADAMW_ANYPRECISION = "adamw_anyprecision" |
| ADAMW_TORCH_4BIT = "adamw_torch_4bit" |
| + ADEMAMIX = "ademamix" |
| SGD = "sgd" |
| ADAGRAD = "adagrad" |
| ADAMW_BNB = "adamw_bnb_8bit" |
| ADAMW_8BIT = "adamw_8bit" # just an alias for adamw_bnb_8bit |
| + ADEMAMIX_8BIT = "ademamix_8bit" |
| LION_8BIT = "lion_8bit" |
| LION = "lion_32bit" |
| PAGED_ADAMW = "paged_adamw_32bit" |
| PAGED_ADAMW_8BIT = "paged_adamw_8bit" |
| + PAGED_ADEMAMIX = "paged_ademamix_32bit" |
| + PAGED_ADEMAMIX_8BIT = "paged_ademamix_8bit" |
| PAGED_LION = "paged_lion_32bit" |
| PAGED_LION_8BIT = "paged_lion_8bit" |
| RMSPROP = "rmsprop" |
| @@ -618,7 +622,7 @@ class TrainingArguments: |
| "adafactor". See `OptimizerNames` in [training_args.py](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py) |
| for a full list of optimizers. |
| optim_args (`str`, *optional*): |
| - Optional arguments that are supplied to AnyPrecisionAdamW. |
| + Optional arguments that are supplied to optimizers such as AnyPrecisionAdamW, AdEMAMix, and GaLore. |
| group_by_length (`bool`, *optional*, defaults to `False`): |
| Whether or not to group together samples of roughly the same length in the training dataset (to minimize |
| padding applied and be more efficient). Only useful if applying dynamic padding. |
| |
| |
| |
| |
| @@ -15,6 +15,7 @@ |
| |
| import dataclasses |
| import gc |
| +import importlib |
| import json |
| import math |
| import os |
| @@ -32,6 +33,7 @@ |
| |
| import numpy as np |
| from huggingface_hub import HfFolder, ModelCard, create_branch, delete_repo, list_repo_commits, list_repo_files |
| +from packaging import version |
| from parameterized import parameterized |
| from requests.exceptions import HTTPError |
| |
| @@ -1091,6 +1093,40 @@ def test_rmsprop_bnb(self): |
| # Check that it trains without errors |
| trainer.train() |
| |
| + @require_bitsandbytes |
| + def test_ademamix_bnb(self): |
| + config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4) |
| + tiny_gpt2 = GPT2LMHeadModel(config) |
| + x = torch.randint(0, 100, (128,)) |
| + train_dataset = RepeatDataset(x) |
| + |
| + with tempfile.TemporaryDirectory() as tmpdir: |
| + # Trainer without inf/nan filter |
| + args = TrainingArguments( |
| + tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="ademamix" |
| + ) |
| + trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset) |
| + |
| + # Check that it trains without errors |
| + trainer.train() |
| + |
| + @require_bitsandbytes |
| + def test_ademamix_bnb_8bit(self): |
| + config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4) |
| + tiny_gpt2 = GPT2LMHeadModel(config) |
| + x = torch.randint(0, 100, (128,)) |
| + train_dataset = RepeatDataset(x) |
| + |
| + with tempfile.TemporaryDirectory() as tmpdir: |
| + # Trainer without inf/nan filter |
| + args = TrainingArguments( |
| + tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="ademamix_8bit" |
| + ) |
| + trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset) |
| + |
| + # Check that it trains without errors |
| + trainer.train() |
| + |
| @require_bitsandbytes |
| def test_rmsprop_bnb_8bit(self): |
| config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4) |
| @@ -4187,6 +4223,13 @@ def hp_name(trial): |
| "lr": TrainingArguments.learning_rate, |
| } |
| |
| + default_ademamix_kwargs = { |
| + "betas": (TrainingArguments.adam_beta1, TrainingArguments.adam_beta2, 0.9999), |
| + "alpha": 5.0, |
| + "eps": TrainingArguments.adam_epsilon, |
| + "lr": TrainingArguments.learning_rate, |
| + } |
| + |
| default_anyprecision_kwargs = { |
| "use_kahan_summation": False, |
| "momentum_dtype": torch.float32, |
| @@ -4291,6 +4334,36 @@ def hp_name(trial): |
| ) |
| ) |
| |
| + if version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.44.0"): |
| + optim_test_params.append( |
| + ( |
| + TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None"), |
| + bnb.optim.AdEMAMix, |
| + default_ademamix_kwargs, |
| + ) |
| + ) |
| + optim_test_params.append( |
| + ( |
| + TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None"), |
| + bnb.optim.AdEMAMix, |
| + default_ademamix_kwargs, |
| + ) |
| + ) |
| + optim_test_params.append( |
| + ( |
| + TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None"), |
| + bnb.optim.AdEMAMix, |
| + default_ademamix_kwargs, |
| + ) |
| + ) |
| + optim_test_params.append( |
| + ( |
| + TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None"), |
| + bnb.optim.AdEMAMix, |
| + default_ademamix_kwargs, |
| + ) |
| + ) |
| + |
| if is_torchdistx_available(): |
| import torchdistx |
| |
| @@ -4420,6 +4493,62 @@ def test_bnb_paged_adam8bit(self): |
| default_adam_kwargs, |
| ) |
| |
| + def test_bnb_ademamix(self): |
| + mock = Mock() |
| + modules = { |
| + "bitsandbytes": mock, |
| + "bitsandbytes.optim": mock.optim, |
| + "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix, |
| + } |
| + with patch.dict("sys.modules", modules): |
| + self.check_optim_and_kwargs( |
| + TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None"), |
| + mock.optim.AdEMAMix, |
| + default_ademamix_kwargs, |
| + ) |
| + |
| + def test_bnb_ademamix8bit(self): |
| + mock = Mock() |
| + modules = { |
| + "bitsandbytes": mock, |
| + "bitsandbytes.optim": mock.optim, |
| + "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix, |
| + } |
| + with patch.dict("sys.modules", modules): |
| + self.check_optim_and_kwargs( |
| + TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None"), |
| + mock.optim.AdEMAMix, |
| + default_ademamix_kwargs, |
| + ) |
| + |
| + def test_bnb_paged_ademamix(self): |
| + mock = Mock() |
| + modules = { |
| + "bitsandbytes": mock, |
| + "bitsandbytes.optim": mock.optim, |
| + "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix, |
| + } |
| + with patch.dict("sys.modules", modules): |
| + self.check_optim_and_kwargs( |
| + TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None"), |
| + mock.optim.AdEMAMix, |
| + default_ademamix_kwargs, |
| + ) |
| + |
| + def test_bnb_paged_ademamix8bit(self): |
| + mock = Mock() |
| + modules = { |
| + "bitsandbytes": mock, |
| + "bitsandbytes.optim": mock.optim, |
| + "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix, |
| + } |
| + with patch.dict("sys.modules", modules): |
| + self.check_optim_and_kwargs( |
| + TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None"), |
| + mock.optim.AdEMAMix, |
| + default_ademamix_kwargs, |
| + ) |
| + |
| def test_bnb_lion(self): |
| mock = Mock() |
| modules = { |
| @@ -4503,6 +4632,42 @@ def test_bnb_paged_adam8bit_no_bnb(self): |
| with self.assertRaises(ValueError): |
| Trainer.get_optimizer_cls_and_kwargs(args) |
| |
| + def test_bnb_ademamix_no_bnb(self): |
| + args = TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None") |
| + |
| + # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing |
| + # bnb will fail even if `bitsandbytes` is installed. |
| + with patch.dict("sys.modules", {"bitsandbytes.optim": None}): |
| + with self.assertRaises(ValueError): |
| + Trainer.get_optimizer_cls_and_kwargs(args) |
| + |
| + def test_bnb_ademamix8bit_no_bnb(self): |
| + args = TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None") |
| + |
| + # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing |
| + # bnb will fail even if `bitsandbytes` is installed. |
| + with patch.dict("sys.modules", {"bitsandbytes.optim": None}): |
| + with self.assertRaises(ValueError): |
| + Trainer.get_optimizer_cls_and_kwargs(args) |
| + |
| + def test_bnb_paged_ademamix_no_bnb(self): |
| + args = TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None") |
| + |
| + # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing |
| + # bnb will fail even if `bitsandbytes` is installed. |
| + with patch.dict("sys.modules", {"bitsandbytes.optim": None}): |
| + with self.assertRaises(ValueError): |
| + Trainer.get_optimizer_cls_and_kwargs(args) |
| + |
| + def test_bnb_paged_ademamix8bit_no_bnb(self): |
| + args = TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None") |
| + |
| + # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing |
| + # bnb will fail even if `bitsandbytes` is installed. |
| + with patch.dict("sys.modules", {"bitsandbytes.optim": None}): |
| + with self.assertRaises(ValueError): |
| + Trainer.get_optimizer_cls_and_kwargs(args) |
| + |
| def test_bnb_paged_lion_no_bnb(self): |
| args = TrainingArguments(optim=OptimizerNames.PAGED_LION, output_dir="None") |
| |
|
|