WaveGen / nano_WaveGen /utils /gpu_utils.py
FangSen9000's picture
Upload nano_WaveGen
8e263cf verified
"""Thin wrapper around shared gpu_utils helpers."""
import sys
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
# Directory four levels above this file (parents[3]); the shared helpers are
# expected at <that dir>/utils/gpu_utils.py.
# NOTE(review): confirm parents[3] matches the deployed directory layout.
_PROJECT_ROOT = Path(__file__).resolve().parents[3]
_SHARED_GPU_UTILS = _PROJECT_ROOT / "utils" / "gpu_utils.py"

# spec_from_file_location() does not check that the file exists: for a
# ``.py`` path it returns a spec regardless, and a missing file would only
# surface later as FileNotFoundError (an OSError, not an ImportError) from
# exec_module(). Check explicitly so callers catching ImportError see it.
if not _SHARED_GPU_UTILS.is_file():
    raise ModuleNotFoundError(
        f"Shared gpu_utils module not found at {_SHARED_GPU_UTILS}"
    )

_spec = spec_from_file_location(
    "wavegen_shared_gpu_utils", _SHARED_GPU_UTILS
)
# Defensive: a spec or loader can still be missing for unsupported suffixes.
if _spec is None or _spec.loader is None:
    raise ModuleNotFoundError(
        f"Shared gpu_utils module not found at {_SHARED_GPU_UTILS}"
    )

_module = module_from_spec(_spec)
# Register before executing (importlib's "import a source file directly"
# recipe) so code in the shared module that resolves itself through
# sys.modules (pickling, dataclasses, typing.get_type_hints) works.
sys.modules[_spec.name] = _module
try:
    _spec.loader.exec_module(_module)
except BaseException:
    # Do not leave a half-initialised module registered on failure.
    sys.modules.pop(_spec.name, None)
    raise
# Expose the shared helpers from this module's own namespace, so existing
# ``from <this module> import select_gpus``-style imports keep working.
DEFAULT_THRESHOLD_MB, query_gpu_memory, select_gpus = (
    _module.DEFAULT_THRESHOLD_MB,
    _module.query_gpu_memory,
    _module.select_gpus,
)

# Names exported by ``from <this module> import *``.
__all__ = ["DEFAULT_THRESHOLD_MB", "query_gpu_memory", "select_gpus"]