"""
Helpers for picking an onnxruntime execution provider and opening ONNX models,
installing onnxruntime on demand when it is missing.
"""
import logging
import os
import shutil
from functools import lru_cache
from typing import Optional

from hbutils.system import pip_install


def _ensure_onnxruntime():
    """
    Make sure ``onnxruntime`` is importable, installing it on the fly when missing.

    The GPU build is preferred when an NVIDIA driver is detected
    (``nvidia-smi`` on the PATH); otherwise the CPU build is installed.
    """
    try:
        import onnxruntime
    except ImportError:  # ModuleNotFoundError is a subclass of ImportError
        logging.warning('onnxruntime not installed, preparing to install ...')
        if shutil.which('nvidia-smi'):  # an NVIDIA GPU is likely present
            logging.info('Installing onnxruntime-gpu ...')
            pip_install(['onnxruntime-gpu'], silent=True)
        else:
            logging.info('Installing onnxruntime (cpu) ...')
            pip_install(['onnxruntime'], silent=True)


_ensure_onnxruntime()

# onnxruntime is imported only after _ensure_onnxruntime() has run, so the
# package is guaranteed to be installed at this point.
from onnxruntime import get_available_providers, get_all_providers, \
    InferenceSession, SessionOptions, GraphOptimizationLevel

# Short provider aliases accepted by get_onnx_provider().
alias = {
    'gpu': "CUDAExecutionProvider",
    'trt': "TensorrtExecutionProvider",
}


def get_onnx_provider(provider: Optional[str] = None) -> str:
    """
    Resolve ``provider`` (a full name or short alias) to a canonical
    onnxruntime execution provider name.

    When ``provider`` is empty, ``CUDAExecutionProvider`` is preferred if
    available, with ``CPUExecutionProvider`` as the fallback.
    """
    if not provider:
        # Auto-detect: prefer CUDA when the current build supports it.
        if "CUDAExecutionProvider" in get_available_providers():
            return "CUDAExecutionProvider"
        else:
            return "CPUExecutionProvider"
    elif provider.lower() in alias:
        # Short alias, e.g. 'gpu' or 'trt'.
        return alias[provider.lower()]
    else:
        # Accept both full names (e.g. 'CUDAExecutionProvider') and short
        # forms (e.g. 'cuda'), matched case-insensitively.
        for p in get_all_providers():
            if provider.lower() == p.lower() or f'{provider}ExecutionProvider'.lower() == p.lower():
                return p

        raise ValueError(f'Unsupported provider {provider!r}, '
                         f'expected one of {get_all_providers()!r}.')
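
# A rough sketch of how the resolution above behaves; the exact set of
# available providers depends on the local onnxruntime build:
#
#     get_onnx_provider()       # "CUDAExecutionProvider" if available, else "CPUExecutionProvider"
#     get_onnx_provider('gpu')  # "CUDAExecutionProvider", via the alias table
#     get_onnx_provider('cpu')  # "CPUExecutionProvider", matched as 'cpu' + 'ExecutionProvider'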


@lru_cache()
def _open_onnx_model(ckpt: str, provider: Optional[str] = None) -> InferenceSession:
    """
    Open the ONNX model at ``ckpt`` with the given execution provider.

    Sessions are memoized with ``lru_cache``, so each ``(ckpt, provider)``
    pair is only loaded once per process.
    """
    options = SessionOptions()
    options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
    provider = provider or get_onnx_provider()
    if provider == "CPUExecutionProvider":
        # Use all available cores for intra-op parallelism on CPU.
        options.intra_op_num_threads = os.cpu_count()

    logging.info(f'Loading model {ckpt!r} with provider {provider!r} ...')
    return InferenceSession(ckpt, options, [provider])
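
# Minimal usage sketch; 'model.onnx' and `dummy_input` below are hypothetical
# placeholders for illustration, not part of this module:
#
#     import numpy as np
#
#     session = _open_onnx_model('model.onnx')
#     input_name = session.get_inputs()[0].name
#     dummy_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
#     outputs = session.run(None, {input_name: dummy_input})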
|
|