| | """LlamaFactory test configuration. |
| | |
| | Contains shared fixtures, pytest configuration, and custom markers. |
| | """ |
| |
|
| | import os |
| |
|
| | import pytest |
| | import torch |
| | import torch.distributed as dist |
| | from pytest import Config, FixtureRequest, Item, MonkeyPatch |
| |
|
| | from llamafactory.extras.misc import get_current_device, get_device_count, is_env_enabled |
| | from llamafactory.extras.packages import is_transformers_version_greater_than |
| | from llamafactory.train.test_utils import patch_valuehead_model |
| |
|
| |
|
| | CURRENT_DEVICE = get_current_device().type |
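# Device type string of the current accelerator: "cuda", "npu", "mps", or "cpu".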


def pytest_configure(config: Config):
    """Register custom pytest markers."""
    config.addinivalue_line(
        "markers",
        "slow: marks tests as slow (deselect with '-m \"not slow\"' or set RUN_SLOW=1 to run)",
    )
    config.addinivalue_line(
        "markers",
        "runs_on: test requires specific device type, e.g., @pytest.mark.runs_on(['cuda'])",
    )
    config.addinivalue_line(
        "markers",
        "require_distributed(num_devices, devices): test requires multi-device execution (default: 2 devices)",
    )
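

# A minimal sketch of how a test might combine the markers registered above
# (the test name and body are hypothetical, shown for illustration only):
#
#   @pytest.mark.slow
#   @pytest.mark.runs_on(["cuda", "npu"])
#   @pytest.mark.require_distributed(2)
#   def test_multi_device_training():
#       ...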


def _handle_runs_on(items: list[Item]):
    """Skip tests whose `runs_on` device types (cpu/cuda/npu) do not include the current device."""
    for item in items:
        marker = item.get_closest_marker("runs_on")
        if not marker:
            continue

        devices = marker.args[0]
        if isinstance(devices, str):
            devices = [devices]

        if CURRENT_DEVICE not in devices:
            item.add_marker(pytest.mark.skip(reason=f"test requires one of {devices} (current: {CURRENT_DEVICE})"))


def _handle_slow_tests(items: list[Item]):
    """Skip slow tests unless RUN_SLOW is enabled."""
    if not is_env_enabled("RUN_SLOW"):
        skip_slow = pytest.mark.skip(reason="slow test (set RUN_SLOW=1 to run)")
        for item in items:
            if "slow" in item.keywords:
                item.add_marker(skip_slow)
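

# For example, slow tests can be included with an env var (illustrative shell
# invocation; the test directory path is an assumption):
#
#   RUN_SLOW=1 pytest tests/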


def _get_visible_devices_env() -> str | None:
    """Return the device-visibility env var name for the current device type, or None."""
    if CURRENT_DEVICE == "cuda":
        return "CUDA_VISIBLE_DEVICES"
    elif CURRENT_DEVICE == "npu":
        return "ASCEND_RT_VISIBLE_DEVICES"
    else:
        return None


def _handle_device_visibility(items: list[Item]):
    """Skip `require_distributed` tests when fewer devices are visible than required."""
    env_key = _get_visible_devices_env()
    if env_key is None or CURRENT_DEVICE in ("cpu", "mps"):
        return

    visible_devices_env = os.environ.get(env_key)
    if visible_devices_env is None:
        available = get_device_count()
    else:
        visible_devices = [v for v in visible_devices_env.split(",") if v != ""]
        available = len(visible_devices)

    for item in items:
        marker = item.get_closest_marker("require_distributed")
        if not marker:
            continue

        required = marker.args[0] if marker.args else 2
        if available < required:
            item.add_marker(pytest.mark.skip(reason=f"test requires {required} devices, but only {available} visible"))
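

# For instance, restricting visibility to a single GPU (illustrative invocation,
# assuming a CUDA machine):
#
#   CUDA_VISIBLE_DEVICES=0 pytest tests/
#
# would cause tests marked @pytest.mark.require_distributed(2) to be skipped.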


def pytest_collection_modifyitems(config: Config, items: list[Item]):
    """Modify test collection based on markers and environment."""
    skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
    for item in items:
        if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"):
            item.add_marker(skip_bc)

    _handle_slow_tests(items)
    _handle_runs_on(items)
    _handle_device_visibility(items)


@pytest.fixture(autouse=True)
def _cleanup_distributed_state():
    """Clean up distributed state after each test."""
    yield
    if dist.is_initialized():
        dist.destroy_process_group()


@pytest.fixture(autouse=True)
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
    """Set environment variables for distributed tests if specific devices are requested."""
    env_key = _get_visible_devices_env()
    if not env_key:
        return

    old_value = os.environ.get(env_key)

    marker = request.node.get_closest_marker("require_distributed")
    if marker:
        required = marker.args[0] if marker.args else 2
        specific_devices = marker.args[1] if len(marker.args) > 1 else None

        if specific_devices:
            devices_str = ",".join(map(str, specific_devices))
        else:
            devices_str = ",".join(str(i) for i in range(required))

        monkeypatch.setenv(env_key, devices_str)
        monkeypatch.syspath_prepend(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
    else:
        if old_value:
            visible_devices = [v for v in old_value.split(",") if v != ""]
            monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
        else:
            monkeypatch.setenv(env_key, "0")

        if CURRENT_DEVICE == "cuda":
            monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
        elif CURRENT_DEVICE == "npu":
            monkeypatch.setattr(torch.npu, "device_count", lambda: 1)
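

# The marker's optional second argument pins a test to specific device IDs,
# as handled by the fixture above (a hypothetical test, for illustration):
#
#   @pytest.mark.require_distributed(2, [0, 1])
#   def test_two_device_all_reduce():
#       ...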


@pytest.fixture
def fix_valuehead_cpu_loading():
    """Patch the valuehead model so its checkpoints load correctly on CPU."""
    patch_valuehead_model()
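

# Tests opt in by requesting the fixture by name (an illustrative sketch;
# the test name is hypothetical):
#
#   def test_load_valuehead_model(fix_valuehead_cpu_loading):
#       ...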