sv-task / src /pipelines /predict_pipeline.py
lamossta's picture
pipeline classes
2d31cef
raw
history blame contribute delete
848 Bytes
from pathlib import Path
import onnxruntime as ort
from src.models.inference import build_tokenizer, predict
class PredictPipeline:
    """Hold an ONNX Runtime session and a tokenizer, and run predictions.

    The constructor creates the inference session and the tokenizer once;
    ``run`` then forwards samples, together with those stored objects, to
    ``predict`` from ``src.models.inference``.
    """

    def __init__(self, onnx_path: Path, mode: str, model_name: str):
        """Create the inference session and tokenizer for this pipeline.

        Args:
            onnx_path: Location of the ONNX model file; it is converted to
                ``str`` before being handed to ``ort.InferenceSession``.
            mode: Passed to ``build_tokenizer`` and later forwarded to
                ``predict`` on every ``run`` call.
            model_name: Stored on the instance as ``self.model_name``; no
                method visible in this class reads it.
        """
        # Plain attribute bookkeeping first; it has no side effects.
        self.mode = mode
        self.model_name = model_name

        # The session is created before the tokenizer, as in the original.
        model_file = str(onnx_path)
        self.session = ort.InferenceSession(model_file)
        self.tokenizer = build_tokenizer(mode)

    def run(
        self,
        samples: list[dict],
        max_len: int = 256,
        batch_size: int = 32,
        deduplicate: bool = False,
    ) -> list[dict]:
        """Run ``predict`` on ``samples`` using the stored session and tokenizer.

        Args:
            samples: Input records, forwarded to ``predict`` unchanged.
            max_len: Forwarded as ``max_len`` (presumably a token-length
                limit -- confirm against ``predict``).
            batch_size: Forwarded as ``batch_size``.
            deduplicate: Forwarded as ``deduplicate``; its exact effect is
                defined by ``predict``.

        Returns:
            Whatever ``predict`` returns, annotated here as ``list[dict]``.
        """
        # Collect every keyword argument in one place, then delegate.
        predict_kwargs = {
            "samples": samples,
            "session": self.session,
            "tokenizer": self.tokenizer,
            "mode": self.mode,
            "max_len": max_len,
            "batch_size": batch_size,
            "deduplicate": deduplicate,
        }
        return predict(**predict_kwargs)