"""ONNX model loader with provider auto-detection."""
import logging
from pathlib import Path
from typing import Optional, Tuple

import onnxruntime as ort
def load_model(model_path: str) -> Tuple[Optional["ort.InferenceSession"], Optional[str]]:
    """Load an ONNX model with execution-provider auto-detection.

    Prefers CUDA over CPU when available, falling back to whatever
    providers the installed onnxruntime build reports.

    Args:
        model_path: Filesystem path to the ``.onnx`` model file.

    Returns:
        ``(session, input_name)`` on success, where ``input_name`` is the
        name of the model's first input; ``(None, None)`` if the file is
        missing or the session cannot be created.
    """
    logger = logging.getLogger(__name__)
    if not Path(model_path).exists():
        logger.warning("ONNX model file not found: %s", model_path)
        return None, None
    try:
        sess_options = ort.SessionOptions()
        # Enable the full graph-optimization pipeline (constant folding,
        # node fusion, layout transforms).
        sess_options.graph_optimization_level = (
            ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        )
        sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        available_providers = ort.get_available_providers()
        # Keep only the providers we know how to rank, in preference order;
        # if neither is present, hand the full available list to ORT.
        preferred_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        providers = [p for p in preferred_providers if p in available_providers]
        if not providers:
            providers = available_providers
        ort_session = ort.InferenceSession(
            model_path, sess_options=sess_options, providers=providers
        )
        input_name = ort_session.get_inputs()[0].name
        return ort_session, input_name
    except Exception:
        # Best-effort loader: callers treat (None, None) as failure, but
        # log the traceback so load errors are not silently invisible.
        logger.exception("Failed to load ONNX model: %s", model_path)
        return None, None