from pathlib import Path

import onnxruntime as ort

from src.models.inference import build_tokenizer, predict
class PredictPipeline:
    """Thin wrapper tying an ONNX inference session to its tokenizer.

    Holds the ONNX Runtime session, the tokenizer matching *mode*, and
    the model name, then delegates batched prediction to
    ``src.models.inference.predict``.
    """

    def __init__(
        self,
        onnx_path: Path,
        mode: str,
        model_name: str,
    ):
        """Load the ONNX model at *onnx_path* and build the tokenizer for *mode*.

        Args:
            onnx_path: Filesystem path to the exported ONNX model.
            mode: Tokenization/inference mode passed to ``build_tokenizer``
                and later to ``predict``.
            model_name: Identifier kept on the instance for reference.
        """
        # InferenceSession expects a string path, not a Path object.
        self.session = ort.InferenceSession(str(onnx_path))
        self.tokenizer = build_tokenizer(mode)
        self.mode = mode
        self.model_name = model_name

    def run(
        self,
        samples: list[dict],
        max_len: int = 256,
        batch_size: int = 32,
        deduplicate: bool = False,
    ) -> list[dict]:
        """Run inference over *samples* and return per-sample predictions.

        Args:
            samples: Input records to score (schema defined by ``predict``).
            max_len: Maximum token length per sample.
            batch_size: Number of samples per inference batch.
            deduplicate: Whether ``predict`` should drop duplicate results.

        Returns:
            The list of prediction dicts produced by ``predict``.
        """
        # Pure delegation: all state (session/tokenizer/mode) was fixed
        # at construction time; only per-call knobs are forwarded here.
        result = predict(
            samples=samples,
            session=self.session,
            tokenizer=self.tokenizer,
            mode=self.mode,
            max_len=max_len,
            batch_size=batch_size,
            deduplicate=deduplicate,
        )
        return result