"""Thin wrapper around the shared ``utils/gpu_utils.py`` helpers.

The shared helpers live outside this package, at
``<project root>/utils/gpu_utils.py``, so they are loaded directly from that
file path via :mod:`importlib.util` instead of a regular ``import``. Only the
names listed in ``__all__`` are re-exported.
"""
import sys
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path

# parents[0] is this file's directory; parents[3] is three levels above it.
_PROJECT_ROOT = Path(__file__).resolve().parents[3]
_SHARED_GPU_UTILS = _PROJECT_ROOT / "utils" / "gpu_utils.py"
# Private, project-specific name so the loaded copy cannot collide with any
# other module called ``gpu_utils`` in sys.modules.
_MODULE_NAME = "wavegen_shared_gpu_utils"

# spec_from_file_location() does not check that the file exists (it only picks
# a loader from the ".py" suffix), so check explicitly. Otherwise a missing
# file would surface as a bare FileNotFoundError from exec_module() below.
if not _SHARED_GPU_UTILS.is_file():
    raise ModuleNotFoundError(
        f"Shared gpu_utils module not found at {_SHARED_GPU_UTILS}",
        name=_MODULE_NAME,
        path=str(_SHARED_GPU_UTILS),
    )

_spec = spec_from_file_location(_MODULE_NAME, _SHARED_GPU_UTILS)
if _spec is None or _spec.loader is None:
    # Defensive: the file exists but no loader could be chosen for it.
    raise ModuleNotFoundError(
        f"Cannot load shared gpu_utils module from {_SHARED_GPU_UTILS}",
        name=_MODULE_NAME,
        path=str(_SHARED_GPU_UTILS),
    )

_module = module_from_spec(_spec)
# Register before executing, as the importlib "import a source file directly"
# recipe does. Code in the shared module that resolves its own module through
# sys.modules (dataclasses and pickle both do) then works as with a normal
# import.
sys.modules[_MODULE_NAME] = _module
try:
    _spec.loader.exec_module(_module)
except BaseException:
    # Do not leave a half-initialised module registered after a failed load.
    sys.modules.pop(_MODULE_NAME, None)
    raise

# Public re-exports from the shared module.
DEFAULT_THRESHOLD_MB = _module.DEFAULT_THRESHOLD_MB
query_gpu_memory = _module.query_gpu_memory
select_gpus = _module.select_gpus

__all__ = [
    "DEFAULT_THRESHOLD_MB",
    "query_gpu_memory",
    "select_gpus",
]