| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """LlamaFactory test configuration. |
| |
| Contains shared fixtures, pytest configuration, and custom markers. |
| """ |
|
|
| import os |
| import sys |
|
|
| import pytest |
| import torch |
| import torch.distributed as dist |
| from pytest import Config, FixtureRequest, Item, MonkeyPatch |
|
|
| from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count |
| from llamafactory.v1.utils.env import is_env_enabled |
| from llamafactory.v1.utils.packages import is_transformers_version_greater_than |
|
|
|
|
# Device-type string of the active accelerator, compared below against "cuda", "npu", "cpu"
# and "mps". Resolved once at import time; the temporary handle is dropped afterwards.
_accelerator = get_current_accelerator()
CURRENT_DEVICE = _accelerator.type
del _accelerator
|
|
|
|
def pytest_configure(config: Config):
    """Declare the suite's custom markers so pytest recognizes them (and ``--strict-markers`` passes)."""
    marker_specs = (
        "slow: marks tests as slow (deselect with '-m \"not slow\"' or set RUN_SLOW=1 to run)",
        "runs_on: test requires specific device type, e.g., @pytest.mark.runs_on(['cuda'])",
        "require_distributed(num_devices): allow multi-device execution (default: 2)",
    )
    for spec in marker_specs:
        config.addinivalue_line("markers", spec)
|
|
|
|
def _handle_runs_on(items: list[Item]):
    """Skip tests whose ``runs_on`` marker does not list the current device type (cpu/cuda/npu/...).

    The marker's first positional argument is either a single device-type string or a
    collection of them.
    """
    for item in items:
        runs_on = item.get_closest_marker("runs_on")
        if not runs_on:
            continue

        allowed = runs_on.args[0]
        # Normalize a bare string such as "cuda" into a one-element list.
        allowed = [allowed] if isinstance(allowed, str) else allowed
        if CURRENT_DEVICE in allowed:
            continue

        item.add_marker(pytest.mark.skip(reason=f"test requires one of {allowed} (current: {CURRENT_DEVICE})"))
|
|
|
|
def _handle_slow_tests(items: list[Item]):
    """Attach a skip marker to every ``slow`` test unless the RUN_SLOW env flag is enabled."""
    if is_env_enabled("RUN_SLOW"):
        return

    skip_slow = pytest.mark.skip(reason="slow test (set RUN_SLOW=1 to run)")
    slow_items = [candidate for candidate in items if "slow" in candidate.keywords]
    for item in slow_items:
        item.add_marker(skip_slow)
|
|
|
|
def _get_visible_devices_env() -> str | None:
    """Return the env var that restricts visible devices for the current accelerator.

    Returns ``None`` for device types that have no such variable (anything other than
    cuda and npu).
    """
    visibility_vars = {
        "cuda": "CUDA_VISIBLE_DEVICES",
        "npu": "ASCEND_RT_VISIBLE_DEVICES",
    }
    return visibility_vars.get(CURRENT_DEVICE)
|
|
|
|
def _handle_device_visibility(items: list[Item]):
    """Skip ``require_distributed`` tests that need more devices than are currently visible.

    The visible count is the number of non-empty entries in the accelerator's visibility
    env var. When that variable is unset, ``get_device_count()`` is used instead.
    The marker's first argument is the required device count (default 2).
    """
    env_key = _get_visible_devices_env()
    # Only accelerators with a visibility env var are handled; cpu/mps are excluded explicitly.
    if env_key is None or CURRENT_DEVICE in ("cpu", "mps"):
        return

    raw_visible = os.environ.get(env_key)
    available = get_device_count() if raw_visible is None else sum(1 for dev in raw_visible.split(",") if dev != "")

    for item in items:
        marker = item.get_closest_marker("require_distributed")
        if not marker:
            continue

        required = marker.args[0] if marker.args else 2
        if available >= required:
            continue

        item.add_marker(pytest.mark.skip(reason=f"test requires {required} devices, but only {available} visible"))
|
|
|
|
def pytest_collection_modifyitems(config: Config, items: list[Item]):
    """Apply collection-time skips based on markers and the runtime environment.

    - Items whose path contains ``tests_v1`` are skipped when
      ``is_transformers_version_greater_than("4.57.0")`` returns False.
    - ``slow``, ``runs_on`` and ``require_distributed`` markers are handled by the
      ``_handle_*`` helpers, in that order.
    """
    # The installed transformers version cannot change during collection, so evaluate the
    # check once up front instead of once per collected item.
    if not is_transformers_version_greater_than("4.57.0"):
        skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
        for item in items:
            if "tests_v1" in str(item.fspath):
                item.add_marker(skip_bc)

    _handle_slow_tests(items)
    _handle_runs_on(items)
    _handle_device_visibility(items)
|
|
|
|
@pytest.fixture(autouse=True)
def _cleanup_distributed_state():
    """Destroy any ``torch.distributed`` process group still initialized after a test.

    Autouse: runs around every test; the cleanup happens after the ``yield``.
    """
    yield
    # On torch builds without distributed support, torch.distributed exposes no APIs beyond
    # is_available(), so is_initialized() would raise AttributeError without this guard.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
|
|
|
|
@pytest.fixture(autouse=True)
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
    """Pin device visibility for each test according to its ``require_distributed`` marker.

    Behavior per test:

    - Tests marked ``require_distributed(n[, devices])`` see either ``devices`` (when
      given) or ids ``0..n-1`` (``n`` defaults to 2).
    - For those tests, the project root is also placed on ``sys.path`` and prepended to
      ``PYTHONPATH``, so spawned worker processes can import it.
    - All other tests are limited to one device: the first entry of the existing
      visibility list, or ``"0"``. The accelerator's ``device_count`` is also patched to
      return 1.

    Environment variables are changed through ``monkeypatch``, so they are restored after
    the test. Does nothing when the accelerator has no visibility env var.
    """
    env_key = _get_visible_devices_env()
    if not env_key:
        return

    # Snapshot before patching; monkeypatch restores the original value on teardown.
    old_value = os.environ.get(env_key)

    marker = request.node.get_closest_marker("require_distributed")
    if marker:
        required = marker.args[0] if marker.args else 2
        specific_devices = marker.args[1] if len(marker.args) > 1 else None

        if specific_devices:
            devices_str = ",".join(map(str, specific_devices))
        else:
            devices_str = ",".join(str(i) for i in range(required))

        monkeypatch.setenv(env_key, devices_str)

        # Make the project root importable for multiprocess runs.
        project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
        if project_root not in sys.path:
            sys.path.insert(0, project_root)

        # Prepend via monkeypatch so the change is undone after the test. A direct os.environ
        # write would persist, gain a duplicate entry per distributed test, and leave an empty
        # (current-directory) entry when PYTHONPATH was previously unset.
        monkeypatch.setenv("PYTHONPATH", project_root, prepend=os.pathsep)
    else:
        # Non-distributed test: expose only a single device.
        if old_value:
            visible_devices = [v for v in old_value.split(",") if v != ""]
            monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
        else:
            monkeypatch.setenv(env_key, "0")

        # Make code under test observe exactly one device.
        # NOTE(review): the npu branch assumes torch_npu is already imported so torch.npu exists.
        if CURRENT_DEVICE == "cuda":
            monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
        elif CURRENT_DEVICE == "npu":
            monkeypatch.setattr(torch.npu, "device_count", lambda: 1)
|
|