Spaces:
Sleeping
Sleeping
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. | |
| # SPDX-FileCopyrightText: All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| Utilities for version compatibility checking. | |
| Specifically, these are used to prevent some newer physicsnemo modules from being used with | |
| an older version of PyTorch. | |
| """ | |
| import importlib | |
| from typing import Optional | |
| from packaging import version | |
# Registry of minimum package versions, keyed by the physicsnemo module (or
# feature name, e.g. "device_mesh") that needs them. Each value maps a package
# name to the lowest acceptable version string. Add entries as new modules
# acquire version constraints.
#
# NOTE(review): the "2.5.9" bound looks deliberate -- it presumably lets
# torch 2.6 pre-release builds (e.g. "2.6.0.devYYYYMMDD", which sort below
# "2.6.0") satisfy the requirement. Confirm before changing it.
VERSION_REQUIREMENTS = {
    "physicsnemo.distributed.shard_tensor": {
        "torch": "2.5.9",
    },
    "device_mesh": {
        "torch": "2.4.0",
    },
}
def check_min_version(
    package_name: str,
    min_version: str,
    error_msg: Optional[str] = None,
    hard_fail: bool = True,
) -> bool:
    """
    Check if an installed package meets the minimum version requirement.

    Args:
        package_name: Import name of the package to check (e.g. 'torch')
        min_version: Minimum required version string (e.g. '2.6.0')
        error_msg: Optional custom error message, used when the package is
            installed but its version is lower than ``min_version``
        hard_fail: Whether to raise an ImportError if the version requirement
            is not met. When False, failures are reported by returning False.

    Returns:
        True if the version requirement is met; False if it is not met and
        ``hard_fail`` is False.

    Raises:
        ImportError: If ``hard_fail`` is True and the package is not
            installed or its version is too low
    """
    try:
        package = importlib.import_module(package_name)
    except ImportError as err:
        if hard_fail:
            # Chain the original error so the real import failure (e.g. a
            # broken transitive dependency) stays visible in the traceback.
            raise ImportError(
                f"Package {package_name} is required but not installed."
            ) from err
        return False

    package_version = getattr(package, "__version__", None)
    if package_version is None:
        # Not every package exposes ``__version__``; consult the installed
        # distribution metadata before falling back to "0.0.0" (which will
        # fail any real minimum-version check).
        from importlib import metadata

        try:
            package_version = metadata.version(package_name)
        except metadata.PackageNotFoundError:
            package_version = "0.0.0"

    if version.parse(package_version) < version.parse(min_version):
        msg = (
            error_msg
            or f"{package_name} version {min_version} or higher is required, but found {package_version}"
        )
        if hard_fail:
            raise ImportError(msg)
        return False

    return True
def check_module_requirements(module_path: str) -> None:
    """
    Enforce every minimum-version requirement registered for a module.

    A module with no entry in ``VERSION_REQUIREMENTS`` is accepted without
    any checks.

    Args:
        module_path: Key into ``VERSION_REQUIREMENTS`` (typically the module's
            import path) whose requirements should be verified

    Raises:
        ImportError: If any required package is missing or too old
    """
    requirements = VERSION_REQUIREMENTS.get(module_path, {})
    for pkg_name, required in requirements.items():
        check_min_version(pkg_name, required)
def require_version(package_name: str, min_version: str):
    """
    Decorator that prevents a function from being called unless the
    specified package meets the minimum version requirement.

    The version check runs each time the decorated function is called, not
    when it is decorated. Decorating therefore never fails, but calling the
    function raises if the requirement is not met. The wrapper preserves the
    wrapped function's metadata (``__name__``, ``__doc__``, ``__wrapped__``,
    ...).

    Args:
        package_name: Name of the package to check
        min_version: Minimum required version string (e.g. '2.3')

    Returns:
        Decorator function that checks version requirement before execution

    Raises:
        ImportError: When the decorated function is called and the package is
            missing or older than ``min_version`` (raised by
            ``check_min_version``)

    Example:
        @require_version("torch", "2.3")
        def my_function():
            # This function will only execute if torch >= 2.3
            pass
    """
    import functools

    def decorator(func):
        # Without functools.wraps the returned wrapper would masquerade as
        # "wrapper", hiding the original name/docstring from introspection.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Verify the package meets minimum version before executing
            check_min_version(package_name, min_version)
            # If we get here, version check passed
            return func(*args, **kwargs)

        return wrapper

    return decorator