|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import os
|
| import subprocess
|
| from pathlib import Path
|
|
|
| import pytest
|
| import torch
|
| import transformers
|
| from packaging.version import Version
|
|
|
| from ..testing_utils import TrlTestCase, require_torch_multi_accelerator
|
|
|
|
|
# Directory two levels above the one containing this file. It is used as the working directory for
# launched commands, so relative script paths such as "trl/scripts/sft.py" must resolve from here
# (presumably the repository root).
ROOT = Path(__file__).resolve().parents[2]


def run_command(command: list[str], env: dict[str, str]) -> None:
    """Run ``command`` with ``ROOT`` as working directory and fail the test on a non-zero exit.

    Args:
        command: Program and arguments, passed to ``subprocess.run`` as a list (no shell). Elements
            may also be path-like objects, e.g. the config path produced by ``get_config_path``.
        env: Environment variables for the child process.

    Raises:
        AssertionError: If the process exits with a non-zero status. The message contains the exit
            code and the full command line so a failing launch can be identified from the test log.
    """
    # Output is not captured: the child's stdout/stderr go straight to the test output.
    result = subprocess.run(command, env=env, cwd=ROOT)
    # Stringify every element, because path-like arguments cannot be passed to str.join directly.
    cmdline = " ".join(str(part) for part in command)
    assert result.returncode == 0, f"Command exited with code {result.returncode}: {cmdline}"
|
|
|
|
|
@pytest.fixture
def get_config_path(lazy_shared_datadir):
    """Return a callable that maps an accelerate config name to its YAML file path.

    Calling the returned function with a name such as ``"ddp"`` or ``"zero3"`` yields
    ``lazy_shared_datadir / "accelerate_configs" / "<name>.yaml"``. The path is built on each call
    rather than at fixture setup, so nothing under ``lazy_shared_datadir`` is touched until a test
    actually asks for a config.
    """

    def resolve_config(config_name):
        configs_dir = lazy_shared_datadir / "accelerate_configs"
        file_name = f"{config_name}.yaml"
        return configs_dir / file_name

    return resolve_config
|
|
|
|
|
# DeepSpeed (ZeRO-2 / ZeRO-3) runs are expected to fail on exactly transformers==5.1.0 because of an
# upstream incompatibility (transformers#43780). Shared by every test that uses it.
_XFAIL_DEEPSPEED_TRANSFORMERS_5_1_0 = pytest.mark.xfail(
    Version(transformers.__version__) == Version("5.1.0"),
    reason="Upstream incompatibility: deepspeed and transformers==5.1.0 (see transformers#43780)",
)

# ZeRO-3 with the online RL scripts (RLOO / GRPO) is expected to fail for transformers >= 5.0.0 and
# < 5.5.4. ``strict=True`` turns an unexpected pass into a failure, so this range must be kept accurate.
_XFAIL_ZERO3_TRANSFORMERS_5 = pytest.mark.xfail(
    Version("5.0.0") <= Version(transformers.__version__) < Version("5.5.4"),
    reason="ZeRO-3 fails with transformers >= 5.0.0 and < 5.5.4 (fixed in transformers#45414), see #4899",
    strict=True,
)

# Default accelerate configs: DDP, DeepSpeed ZeRO-2 / ZeRO-3 and FSDP2.
_DEFAULT_CONFIGS = [
    "ddp",
    pytest.param("zero2", marks=_XFAIL_DEEPSPEED_TRANSFORMERS_5_1_0),
    pytest.param("zero3", marks=_XFAIL_DEEPSPEED_TRANSFORMERS_5_1_0),
    "fsdp2",
]

# SFT + PEFT: the ZeRO stages carry a torch/transformers-dependent xfail instead of the 5.1.0 one.
_SFT_PEFT_CONFIGS = [
    "ddp",
    pytest.param(
        "zero2",
        marks=pytest.mark.xfail(
            condition=Version("2.10") <= Version(torch.__version__)
            and Version(transformers.__version__) < Version("5.1.0"),
            reason="ZeRO 2 + PEFT was failing before transformers 5.1.0 on torch 2.10; see #4884",
        ),
    ),
    pytest.param(
        "zero3",
        marks=pytest.mark.xfail(
            condition=Version("2.10") <= Version(torch.__version__)
            and Version(transformers.__version__) < Version("5.1.0"),
            reason="ZeRO 3 + PEFT was failing before transformers 5.1.0 on torch 2.10; see #4884",
        ),
    ),
    "fsdp2",
]

# GRPO: ZeRO-3 uses the strict transformers 5.x xfail.
_GRPO_CONFIGS = [
    "ddp",
    pytest.param("zero2", marks=_XFAIL_DEEPSPEED_TRANSFORMERS_5_1_0),
    pytest.param("zero3", marks=_XFAIL_ZERO3_TRANSFORMERS_5),
    "fsdp2",
]

# RLOO: like GRPO, plus FSDP2 is skipped on transformers versions with the NaN-weights issue.
_RLOO_CONFIGS = [
    "ddp",
    pytest.param("zero2", marks=_XFAIL_DEEPSPEED_TRANSFORMERS_5_1_0),
    pytest.param("zero3", marks=_XFAIL_ZERO3_TRANSFORMERS_5),
    pytest.param(
        "fsdp2",
        marks=pytest.mark.skipif(
            Version("5.4.0") <= Version(transformers.__version__) < Version("5.6.0"),
            reason="Upstream issue: NaN weights on non-rank-0 FSDP processes (see #5386 and transformers#45050)",
        ),
    ),
]


@require_torch_multi_accelerator
class TestDistributed(TrlTestCase):
    """Smoke tests that run TRL example scripts through ``accelerate launch``.

    Each test launches one script from ``trl/scripts`` with the tiny Qwen2 test checkpoint on a subset
    of the ``trl-internal-testing/zen`` dataset, once per accelerate config, and passes when the
    launched process exits with status 0 (checked by ``run_command``).
    """

    def _run_trl_script(self, config_path, script, dataset_config, *extra_args):
        """Launch ``script`` with ``accelerate launch --config_file config_path`` and shared arguments.

        Args:
            config_path: Accelerate config file, as returned by the ``get_config_path`` fixture.
            script: Script path relative to ``ROOT`` (the working directory used by ``run_command``).
            dataset_config: Value passed as ``--dataset_config`` for ``trl-internal-testing/zen``.
            *extra_args: Additional script arguments, appended after the shared ones.
        """
        run_command(
            [
                "accelerate", "launch", "--config_file", config_path, script,
                "--output_dir", self.tmp_dir,
                "--model_name_or_path", "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                "--dataset_name", "trl-internal-testing/zen",
                "--dataset_config", dataset_config,
                *extra_args,
            ],
            os.environ.copy(),
        )

    @pytest.mark.parametrize("config", _DEFAULT_CONFIGS)
    def test_sft(self, config, get_config_path):
        """SFT on the ``standard_language_modeling`` subset."""
        self._run_trl_script(get_config_path(config), "trl/scripts/sft.py", "standard_language_modeling")

    @pytest.mark.parametrize("config", _DEFAULT_CONFIGS)
    def test_dpo(self, config, get_config_path):
        """DPO on the ``standard_preference`` subset."""
        self._run_trl_script(get_config_path(config), "trl/scripts/dpo.py", "standard_preference")

    @pytest.mark.parametrize("config", _DEFAULT_CONFIGS)
    def test_sft_dataset_streaming(self, config, get_config_path):
        """SFT with ``--dataset_streaming``; ``--max_steps`` bounds the run to 3 steps."""
        self._run_trl_script(
            get_config_path(config),
            "trl/scripts/sft.py",
            "standard_language_modeling",
            "--dataset_streaming",
            "--max_steps", "3",
        )

    @pytest.mark.parametrize("config", _SFT_PEFT_CONFIGS)
    def test_sft_peft(self, config, get_config_path):
        """SFT with ``--use_peft`` on the ``standard_language_modeling`` subset."""
        self._run_trl_script(
            get_config_path(config), "trl/scripts/sft.py", "standard_language_modeling", "--use_peft"
        )

    @pytest.mark.parametrize("config", _DEFAULT_CONFIGS)
    def test_reward(self, config, get_config_path):
        """Reward modeling on the ``conversational_implicit_prompt_preference`` subset."""
        self._run_trl_script(
            get_config_path(config), "trl/scripts/reward.py", "conversational_implicit_prompt_preference"
        )

    @pytest.mark.parametrize("config", _RLOO_CONFIGS)
    def test_rloo(self, config, get_config_path):
        """RLOO on ``conversational_prompt_only`` with a tiny sequence-classification reward model."""
        self._run_trl_script(
            get_config_path(config),
            "trl/scripts/rloo.py",
            "conversational_prompt_only",
            "--reward_model_name_or_path", "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
        )

    @pytest.mark.parametrize("config", _GRPO_CONFIGS)
    def test_grpo(self, config, get_config_path):
        """GRPO on ``conversational_prompt_only`` with a tiny sequence-classification reward model."""
        self._run_trl_script(
            get_config_path(config),
            "trl/scripts/grpo.py",
            "conversational_prompt_only",
            "--reward_model_name_or_path", "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
        )
|
|
|
|
|