File size: 848 Bytes
2d31cef | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 | from pathlib import Path
import onnxruntime as ort
from src.models.inference import build_tokenizer, predict
class PredictPipeline:
    """Pair an ONNX Runtime session with a tokenizer for repeated prediction.

    The session and tokenizer are built once in ``__init__`` and reused by
    every call to :meth:`run`, which delegates the actual work to
    ``src.models.inference.predict``.
    """

    def __init__(
        self,
        onnx_path: Path,
        mode: str,
        model_name: str,
        *,
        providers: "list[str] | None" = None,
    ):
        """Load the ONNX model and build the tokenizer for ``mode``.

        Args:
            onnx_path: Location of the ONNX model file. It is passed through
                ``str()`` before reaching onnxruntime, so a plain string path
                also works.
            mode: Passed to ``build_tokenizer`` and stored so that
                :meth:`run` can forward it to ``predict``.
            model_name: Stored as ``self.model_name``. It is not used by
                :meth:`run`.
            providers: Optional ordered list of ONNX Runtime execution
                providers, e.g. ``["CUDAExecutionProvider",
                "CPUExecutionProvider"]``. ``None`` (the default) leaves
                provider selection to onnxruntime, exactly as before. Some
                onnxruntime builds with several providers enabled (such as
                onnxruntime-gpu) refuse to create a session unless this is
                set explicitly.
        """
        # providers=None is InferenceSession's own default, so omitting the
        # argument here is behavior-identical to the previous version.
        self.session = ort.InferenceSession(str(onnx_path), providers=providers)
        self.tokenizer = build_tokenizer(mode)
        self.mode = mode
        # NOTE(review): stored for callers/introspection only; run() does not
        # forward it to predict(). Confirm whether that is intended.
        self.model_name = model_name

    def run(
        self,
        samples: list[dict],
        max_len: int = 256,
        batch_size: int = 32,
        deduplicate: bool = False,
    ) -> list[dict]:
        """Run inference over ``samples`` using the stored session and tokenizer.

        Thin wrapper around ``predict``. It supplies ``self.session``,
        ``self.tokenizer`` and ``self.mode`` and forwards the remaining
        arguments unchanged.

        Args:
            samples: Input records. The keys they must contain are defined
                by ``predict`` and are not visible here.
            max_len: Forwarded to ``predict``. Presumably the maximum
                tokenized length; confirm in ``predict``.
            batch_size: Forwarded to ``predict``.
            deduplicate: Forwarded to ``predict``.

        Returns:
            The value returned by ``predict``, annotated as ``list[dict]``.
        """
        return predict(
            samples=samples,
            session=self.session,
            tokenizer=self.tokenizer,
            mode=self.mode,
            max_len=max_len,
            batch_size=batch_size,
            deduplicate=deduplicate,
        )
|