# Copyright 2020-2026 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import importlib import importlib.metadata import importlib.util import warnings from contextlib import contextmanager from packaging.version import Version LIGER_KERNEL_MIN_VERSION = "0.7.0" PACKAGE_DISTRIBUTION_MAPPING = importlib.metadata.packages_distributions() # From transformers: https://github.com/huggingface/transformers/blob/556312cd45a5e619c41b0f8adf680eab0d334324/src/transformers/utils/import_utils.py#L48-L77 def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool: """Check if `pkg_name` exist, and optionally try to get its version""" spec = importlib.util.find_spec(pkg_name) package_exists = spec is not None package_version = "N/A" if package_exists and return_version: try: # importlib.metadata works with the distribution package, which may be different from the import # name (e.g. `PIL` is the import name, but `pillow` is the distribution name) distributions = PACKAGE_DISTRIBUTION_MAPPING[pkg_name] # Per PEP 503, underscores and hyphens are equivalent in package names. # Prefer the distribution that matches the (normalized) package name. normalized_pkg_name = pkg_name.replace("_", "-") if normalized_pkg_name in distributions: distribution_name = normalized_pkg_name elif pkg_name in distributions: distribution_name = pkg_name else: distribution_name = distributions[0] package_version = importlib.metadata.version(distribution_name) except (importlib.metadata.PackageNotFoundError, KeyError): # If we cannot find the metadata (because of editable install for example), try to import directly. # Note that this branch will almost never be run, so we do not import packages for nothing here package = importlib.import_module(pkg_name) package_version = getattr(package, "__version__", "N/A") if return_version: return package_exists, package_version else: return package_exists def is_deepspeed_available() -> bool: return _is_package_available("deepspeed") def is_fastapi_available() -> bool: return _is_package_available("fastapi") def is_jmespath_available() -> bool: return _is_package_available("jmespath") def is_joblib_available() -> bool: return _is_package_available("joblib") def is_liger_kernel_available(min_version: str = LIGER_KERNEL_MIN_VERSION) -> bool: _liger_kernel_available, _liger_kernel_version = _is_package_available("liger_kernel", return_version=True) return _liger_kernel_available and Version(_liger_kernel_version) >= Version(min_version) def is_math_verify_available() -> bool: return _is_package_available("math_verify") def is_mergekit_available() -> bool: return _is_package_available("mergekit") def is_pydantic_available() -> bool: return _is_package_available("pydantic") def is_requests_available() -> bool: return _is_package_available("requests") def is_unsloth_available() -> bool: return _is_package_available("unsloth") def is_uvicorn_available() -> bool: return _is_package_available("uvicorn") def is_vllm_available(min_version: str | None = None) -> bool: _vllm_available, _vllm_version = _is_package_available("vllm", return_version=True) if _vllm_available: if not (Version("0.12.0") <= Version(_vllm_version) <= Version("0.18.0")): warnings.warn( f"TRL currently supports vLLM versions from 0.12.0 to 0.18.0. You have version {_vllm_version} " "installed. We recommend installing a supported version to avoid compatibility issues.", stacklevel=2, ) if min_version is not None and Version(_vllm_version) < Version(min_version): return False return _vllm_available def is_vllm_ascend_available() -> bool: return _is_package_available("vllm_ascend") def is_weave_available() -> bool: return _is_package_available("weave") class TRLExperimentalWarning(UserWarning): """Warning for using the 'trl.experimental' submodule.""" pass @contextmanager def suppress_warning(category): with warnings.catch_warnings(): warnings.simplefilter("ignore", category=category) yield def suppress_experimental_warning(): return suppress_warning(TRLExperimentalWarning)