# Deployed by Muhammad7777 — "Deploy application" (commit 7150d27)
from pathlib import Path
import numpy as np
import onnxruntime as ort
from .exceptions import ModelNotFound
class OnnxPredictor:
    """Thin wrapper around an ONNX Runtime session for a two-input model.

    The wrapped model is expected to take two tensors (token IDs and hints)
    and produce one logits tensor.
    """

    def __init__(self, model_path: Path, use_gpu: bool = False):
        """
        Initializes the ONNX Runtime session.

        Args:
            model_path: Path to the .onnx model file.
            use_gpu: Whether to use the GPU for inference. Defaults to False.

        Raises:
            ModelNotFound: If the model file does not exist at the given path.
            ValueError: If the loaded model does not expose the expected
                two inputs (input_ids, hints).
        """
        if not model_path.exists():
            raise ModelNotFound(f"ONNX model file not found at: {model_path}")

        providers = ["CPUExecutionProvider"]
        if use_gpu:
            # Only request CUDA when this onnxruntime build actually ships
            # it; unconditionally listing an unavailable provider triggers
            # warnings (or errors, depending on the ORT version) on
            # CPU-only installs. You can customize this list for other
            # target hardware (e.g. TensorrtExecutionProvider).
            if "CUDAExecutionProvider" in ort.get_available_providers():
                providers.insert(0, "CUDAExecutionProvider")

        self.session = ort.InferenceSession(str(model_path), providers=providers)

        inputs = self.session.get_inputs()
        if len(inputs) < 2:
            # Fail fast with a clear message instead of a cryptic IndexError
            # when a model with the wrong signature is loaded.
            raise ValueError(
                f"Expected an ONNX model with 2 inputs (input_ids, hints), "
                f"but it has {len(inputs)}"
            )
        self.input_name = inputs[0].name
        self.hints_name = inputs[1].name
        self.output_name = self.session.get_outputs()[0].name

    def predict(self, input_ids: np.ndarray, hints: np.ndarray) -> np.ndarray:
        """
        Runs inference on a batch of tokenized input IDs.

        Args:
            input_ids: A numpy array of shape (batch_size, sequence_length).
            hints: A numpy array of shape (batch_size, sequence_length).

        Returns:
            A numpy array of logits of shape (batch_size, sequence_length, num_classes).
        """
        ort_inputs = {self.input_name: input_ids, self.hints_name: hints}
        # session.run returns a list of requested outputs; we only asked for
        # (and only use) the first one.
        logits = self.session.run([self.output_name], ort_inputs)[0]
        return logits