harness / diffs /33682.patch
ArthurZ's picture
ArthurZ HF Staff
Initial harness: 100 perf tasks + Gradio browser
dfefe0b verified
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index e0a49ee5795e..216d5cd42960 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1237,6 +1237,10 @@ def get_optimizer_cls_and_kwargs(
OptimizerNames.ADAMW_8BIT,
OptimizerNames.PAGED_ADAMW,
OptimizerNames.PAGED_ADAMW_8BIT,
+ OptimizerNames.ADEMAMIX,
+ OptimizerNames.ADEMAMIX_8BIT,
+ OptimizerNames.PAGED_ADEMAMIX,
+ OptimizerNames.PAGED_ADEMAMIX_8BIT,
OptimizerNames.LION,
OptimizerNames.LION_8BIT,
OptimizerNames.PAGED_LION,
@@ -1266,6 +1270,33 @@ def get_optimizer_cls_and_kwargs(
# Above we pass all `adam_kwargs` to the optimizer, here
# we only pass `optim_args` which can be passed by the user.
additional_optim_kwargs = optim_args
+ elif "ademamix" in args.optim:
+ if is_bitsandbytes_available() and version.parse(
+ importlib.metadata.version("bitsandbytes")
+ ) < version.parse("0.44.0"):
+ raise ValueError(
+ "The AdEMAMix optimizer is not supported by your current version of `bitsandbytes`. "
+ "Please install `bitsandbytes` >= 0.44.0."
+ )
+
+ from bitsandbytes.optim import AdEMAMix
+
+ optimizer_cls = AdEMAMix
+ additional_optim_kwargs = {
+ "betas": (
+ float(optim_args.get("beta1", args.adam_beta1)),
+ float(optim_args.get("beta2", args.adam_beta2)),
+ float(optim_args.get("beta3", 0.9999)),
+ ),
+ "alpha": float(optim_args.get("alpha", 5.0)),
+ "eps": float(optim_args.get("eps", args.adam_epsilon)),
+ }
+
+ if "t_alpha" in optim_args:
+ additional_optim_kwargs["t_alpha"] = int(optim_args["t_alpha"])
+
+ if "t_beta3" in optim_args:
+ additional_optim_kwargs["t_beta3"] = int(optim_args["t_beta3"])
bnb_kwargs = {"optim_bits": optim_bits}
if "rmsprop" not in args.optim:
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 02413c285832..596917928350 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -155,14 +155,18 @@ class OptimizerNames(ExplicitEnum):
ADAFACTOR = "adafactor"
ADAMW_ANYPRECISION = "adamw_anyprecision"
ADAMW_TORCH_4BIT = "adamw_torch_4bit"
+ ADEMAMIX = "ademamix"
SGD = "sgd"
ADAGRAD = "adagrad"
ADAMW_BNB = "adamw_bnb_8bit"
ADAMW_8BIT = "adamw_8bit" # just an alias for adamw_bnb_8bit
+ ADEMAMIX_8BIT = "ademamix_8bit"
LION_8BIT = "lion_8bit"
LION = "lion_32bit"
PAGED_ADAMW = "paged_adamw_32bit"
PAGED_ADAMW_8BIT = "paged_adamw_8bit"
+ PAGED_ADEMAMIX = "paged_ademamix_32bit"
+ PAGED_ADEMAMIX_8BIT = "paged_ademamix_8bit"
PAGED_LION = "paged_lion_32bit"
PAGED_LION_8BIT = "paged_lion_8bit"
RMSPROP = "rmsprop"
@@ -618,7 +622,7 @@ class TrainingArguments:
"adafactor". See `OptimizerNames` in [training_args.py](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py)
for a full list of optimizers.
optim_args (`str`, *optional*):
- Optional arguments that are supplied to AnyPrecisionAdamW.
+ Optional arguments that are supplied to optimizers such as AnyPrecisionAdamW, AdEMAMix, and GaLore.
group_by_length (`bool`, *optional*, defaults to `False`):
Whether or not to group together samples of roughly the same length in the training dataset (to minimize
padding applied and be more efficient). Only useful if applying dynamic padding.
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 14014e4a0947..0035ff7de8ba 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -15,6 +15,7 @@
import dataclasses
import gc
+import importlib
import json
import math
import os
@@ -32,6 +33,7 @@
import numpy as np
from huggingface_hub import HfFolder, ModelCard, create_branch, delete_repo, list_repo_commits, list_repo_files
+from packaging import version
from parameterized import parameterized
from requests.exceptions import HTTPError
@@ -1091,6 +1093,40 @@ def test_rmsprop_bnb(self):
# Check that it trains without errors
trainer.train()
+ @require_bitsandbytes
+ def test_ademamix_bnb(self):
+ config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
+ tiny_gpt2 = GPT2LMHeadModel(config)
+ x = torch.randint(0, 100, (128,))
+ train_dataset = RepeatDataset(x)
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Trainer without inf/nan filter
+ args = TrainingArguments(
+ tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="ademamix"
+ )
+ trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
+
+ # Check that it trains without errors
+ trainer.train()
+
+ @require_bitsandbytes
+ def test_ademamix_bnb_8bit(self):
+ config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
+ tiny_gpt2 = GPT2LMHeadModel(config)
+ x = torch.randint(0, 100, (128,))
+ train_dataset = RepeatDataset(x)
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Trainer without inf/nan filter
+ args = TrainingArguments(
+ tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="ademamix_8bit"
+ )
+ trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
+
+ # Check that it trains without errors
+ trainer.train()
+
@require_bitsandbytes
def test_rmsprop_bnb_8bit(self):
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
@@ -4187,6 +4223,13 @@ def hp_name(trial):
"lr": TrainingArguments.learning_rate,
}
+ default_ademamix_kwargs = {
+ "betas": (TrainingArguments.adam_beta1, TrainingArguments.adam_beta2, 0.9999),
+ "alpha": 5.0,
+ "eps": TrainingArguments.adam_epsilon,
+ "lr": TrainingArguments.learning_rate,
+ }
+
default_anyprecision_kwargs = {
"use_kahan_summation": False,
"momentum_dtype": torch.float32,
@@ -4291,6 +4334,36 @@ def hp_name(trial):
)
)
+ if version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.44.0"):
+ optim_test_params.append(
+ (
+ TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None"),
+ bnb.optim.AdEMAMix,
+ default_ademamix_kwargs,
+ )
+ )
+ optim_test_params.append(
+ (
+ TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None"),
+ bnb.optim.AdEMAMix,
+ default_ademamix_kwargs,
+ )
+ )
+ optim_test_params.append(
+ (
+ TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None"),
+ bnb.optim.AdEMAMix,
+ default_ademamix_kwargs,
+ )
+ )
+ optim_test_params.append(
+ (
+ TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None"),
+ bnb.optim.AdEMAMix,
+ default_ademamix_kwargs,
+ )
+ )
+
if is_torchdistx_available():
import torchdistx
@@ -4420,6 +4493,62 @@ def test_bnb_paged_adam8bit(self):
default_adam_kwargs,
)
+ def test_bnb_ademamix(self):
+ mock = Mock()
+ modules = {
+ "bitsandbytes": mock,
+ "bitsandbytes.optim": mock.optim,
+ "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
+ }
+ with patch.dict("sys.modules", modules):
+ self.check_optim_and_kwargs(
+ TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None"),
+ mock.optim.AdEMAMix,
+ default_ademamix_kwargs,
+ )
+
+ def test_bnb_ademamix8bit(self):
+ mock = Mock()
+ modules = {
+ "bitsandbytes": mock,
+ "bitsandbytes.optim": mock.optim,
+ "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
+ }
+ with patch.dict("sys.modules", modules):
+ self.check_optim_and_kwargs(
+ TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None"),
+ mock.optim.AdEMAMix,
+ default_ademamix_kwargs,
+ )
+
+ def test_bnb_paged_ademamix(self):
+ mock = Mock()
+ modules = {
+ "bitsandbytes": mock,
+ "bitsandbytes.optim": mock.optim,
+ "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
+ }
+ with patch.dict("sys.modules", modules):
+ self.check_optim_and_kwargs(
+ TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None"),
+ mock.optim.AdEMAMix,
+ default_ademamix_kwargs,
+ )
+
+ def test_bnb_paged_ademamix8bit(self):
+ mock = Mock()
+ modules = {
+ "bitsandbytes": mock,
+ "bitsandbytes.optim": mock.optim,
+ "bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
+ }
+ with patch.dict("sys.modules", modules):
+ self.check_optim_and_kwargs(
+ TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None"),
+ mock.optim.AdEMAMix,
+ default_ademamix_kwargs,
+ )
+
def test_bnb_lion(self):
mock = Mock()
modules = {
@@ -4503,6 +4632,42 @@ def test_bnb_paged_adam8bit_no_bnb(self):
with self.assertRaises(ValueError):
Trainer.get_optimizer_cls_and_kwargs(args)
+ def test_bnb_ademamix_no_bnb(self):
+ args = TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None")
+
+ # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
+ # bnb will fail even if `bitsandbytes` is installed.
+ with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
+ with self.assertRaises(ValueError):
+ Trainer.get_optimizer_cls_and_kwargs(args)
+
+ def test_bnb_ademamix8bit_no_bnb(self):
+ args = TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None")
+
+ # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
+ # bnb will fail even if `bitsandbytes` is installed.
+ with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
+ with self.assertRaises(ValueError):
+ Trainer.get_optimizer_cls_and_kwargs(args)
+
+ def test_bnb_paged_ademamix_no_bnb(self):
+ args = TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None")
+
+ # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
+ # bnb will fail even if `bitsandbytes` is installed.
+ with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
+ with self.assertRaises(ValueError):
+ Trainer.get_optimizer_cls_and_kwargs(args)
+
+ def test_bnb_paged_ademamix8bit_no_bnb(self):
+ args = TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None")
+
+ # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
+ # bnb will fail even if `bitsandbytes` is installed.
+ with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
+ with self.assertRaises(ValueError):
+ Trainer.get_optimizer_cls_and_kwargs(args)
+
def test_bnb_paged_lion_no_bnb(self):
args = TrainingArguments(optim=OptimizerNames.PAGED_LION, output_dir="None")