|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import functools
|
| import signal
|
| import warnings
|
| from collections.abc import Callable
|
|
|
| import psutil
|
| import pytest
|
| import torch
|
| from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available
|
| from transformers.testing_utils import backend_device_count, torch_device
|
| from transformers.utils import (
|
| is_kernels_available,
|
| is_peft_available,
|
| is_rich_available,
|
| is_torch_available,
|
| is_vision_available,
|
| )
|
|
|
| from trl.import_utils import (
|
| is_jmespath_available,
|
| is_joblib_available,
|
| is_liger_kernel_available,
|
| is_math_verify_available,
|
| is_mergekit_available,
|
| is_vllm_available,
|
| )
|
|
|
|
|
# Pytest skip markers that gate tests on optional dependencies and on the available hardware.
require_bitsandbytes = pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes")
require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml")
require_jmespath = pytest.mark.skipif(not is_jmespath_available(), reason="test requires jmespath")
require_kernels = pytest.mark.skipif(not is_kernels_available(), reason="test requires kernels")
require_liger_kernel = pytest.mark.skipif(not is_liger_kernel_available(), reason="test requires liger-kernel")
require_math_latex = pytest.mark.skipif(not is_math_verify_available(), reason="test requires math_verify")
require_mergekit = pytest.mark.skipif(not is_mergekit_available(), reason="test requires mergekit")
require_peft = pytest.mark.skipif(not is_peft_available(), reason="test requires peft")
require_rich = pytest.mark.skipif(not is_rich_available(), reason="test requires rich")
# Both sklearn and joblib are checked, so the skip reason names both.
require_sklearn = pytest.mark.skipif(
    not (is_sklearn_available() and is_joblib_available()), reason="test requires sklearn and joblib"
)
require_torch_accelerator = pytest.mark.skipif(
    torch_device is None or torch_device == "cpu", reason="test requires accelerator"
)
require_torch_multi_accelerator = pytest.mark.skipif(
    not is_torch_available() or backend_device_count(torch_device) <= 1, reason="test requires multiple accelerators"
)
require_vision = pytest.mark.skipif(not is_vision_available(), reason="test requires vision")
require_vllm = pytest.mark.skipif(not is_vllm_available(), reason="test requires vllm")
require_wandb = pytest.mark.skipif(not is_wandb_available(), reason="test requires wandb")
require_no_wandb = pytest.mark.skipif(is_wandb_available(), reason="test requires no wandb")
# Devices are counted with `backend_device_count`, the same helper `require_torch_multi_accelerator` uses, rather
# than `getattr(torch, torch_device, torch.cuda)`, which would count CUDA devices for any device name that has no
# matching `torch` submodule.
require_3_accelerators = pytest.mark.skipif(
    backend_device_count(torch_device) < 3,
    reason=f"test requires at least 3 {torch_device}s",
)
|
|
|
|
|
def is_bitsandbytes_multi_backend_available() -> bool:
    """
    Report whether the installed bitsandbytes advertises the `"multi_backend"` feature.

    Returns:
        `bool`: `False` when bitsandbytes is not installed; otherwise whether `"multi_backend"` is contained in
        `bitsandbytes.features` (an empty set is used if that attribute is missing).
    """
    if not is_bitsandbytes_available():
        return False

    import bitsandbytes as bnb

    advertised_features = getattr(bnb, "features", set())
    return "multi_backend" in advertised_features
|
|
|
|
|
|
|
# Skip unless bitsandbytes reports multi-backend support or the configured torch device is "cuda".
require_torch_gpu_if_bnb_not_multi_backend_enabled = pytest.mark.skipif(
    not (is_bitsandbytes_multi_backend_available() or torch_device == "cuda"),
    reason="test requires bitsandbytes multi-backend enabled or 'cuda' torch device",
)
|
|
|
|
|
def is_ampere_or_newer(device_index=0):
    """
    Check whether a CUDA device has compute capability 8.0 or higher.

    Args:
        device_index (`int`, *optional*, defaults to `0`):
            Index of the CUDA device whose capability is queried.

    Returns:
        `bool`: `False` when CUDA is unavailable; otherwise whether the device's `(major, minor)` capability is at
        least `(8, 0)`.
    """
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        major, minor = torch.cuda.get_device_capability(device_index)
        return (major, minor) >= (8, 0)
    return False
|
|
|
|
|
# Skip marker for tests that need a CUDA device (index 0) with compute capability of at least 8.0.
require_ampere_or_newer = pytest.mark.skipif(
    not is_ampere_or_newer(),
    reason="test requires Ampere or newer GPU",
)
|
|
|
|
|
class TrlTestCase:
    """
    Base class for pytest-style test classes that need a per-test temporary directory.

    An autouse fixture copies pytest's `tmp_path` onto `self.tmp_dir` as a string before each test in the class.
    """

    @pytest.fixture(autouse=True)
    def set_tmp_dir(self, tmp_path):
        # `tmp_path` is injected by pytest by parameter name; tests read it as a plain string path.
        directory_text = str(tmp_path)
        self.tmp_dir = directory_text
|
|
|
|
|
def ignore_warnings(message: "str | None" = None, category: type[Warning] = Warning) -> Callable:
    """
    Decorator to ignore warnings with a specific message and/or category.

    Args:
        message (`str`, *optional*):
            Regex pattern for the warning message to ignore, matched case-insensitively against the start of the
            message (as `warnings.filterwarnings` does). If `None`, all messages are ignored.
        category (`type[Warning]`, *optional*, defaults to `Warning`):
            Warning class to ignore. Defaults to `Warning`, which ignores all warnings.

    Returns:
        `Callable`: A decorator that runs the wrapped function with the matching warnings suppressed.
    """
    # `warnings.filterwarnings` rejects a non-string `message`; the empty pattern matches every message, which is
    # what `None` is documented to mean.
    pattern = "" if message is None else message

    def decorator(test_func):
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            # `catch_warnings` restores the warning filters on exit, so the filter only applies to this call.
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", message=pattern, category=category)
                return test_func(*args, **kwargs)

        return wrapper

    return decorator
|
|
|
|
|
def kill_process(process):
    """
    Terminate a process and all of its descendants, escalating to a kill when they do not exit in time.

    Descendants are sent `SIGTERM` and given up to 5 seconds (collectively) to exit; survivors are killed. The
    process itself is then terminated, waited on for up to 5 seconds, and killed (then reaped) if still running.

    Args:
        process:
            Handle of the process to stop. It must expose `pid`, `terminate()`, `wait(timeout=...)` and `kill()`,
            e.g. a `subprocess.Popen` or a `psutil.Process`.
    """
    import subprocess

    try:
        children = psutil.Process(process.pid).children(recursive=True)
    except psutil.NoSuchProcess:
        # The process is already gone (and reaped), so there is no tree left to walk.
        children = []

    for child in children:
        try:
            child.send_signal(signal.SIGTERM)
        except psutil.NoSuchProcess:
            pass
    # Wait on all children at once instead of up to 5 seconds per child.
    _, still_alive = psutil.wait_procs(children, timeout=5)
    for child in still_alive:
        try:
            child.kill()
        except psutil.NoSuchProcess:
            pass

    try:
        process.terminate()
        try:
            process.wait(timeout=5)
        # `subprocess.Popen.wait` raises `subprocess.TimeoutExpired`, whereas `psutil.Process.wait` raises
        # `psutil.TimeoutExpired`; catch both so the kill fallback runs for either kind of handle.
        except (subprocess.TimeoutExpired, psutil.TimeoutExpired):
            process.kill()
            process.wait()  # reap the killed process so it does not linger as a zombie
    except psutil.NoSuchProcess:
        pass
|
|
|