File size: 848 Bytes
2d31cef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from pathlib import Path

import onnxruntime as ort

from src.models.inference import build_tokenizer, predict


class PredictPipeline:
    """Pair an ONNX Runtime session with a tokenizer for prediction.

    Construction opens an ``ort.InferenceSession`` on the given model file
    and obtains a tokenizer from ``build_tokenizer``. ``run`` hands both,
    together with the caller's options, to ``predict``.
    """

    def __init__(
        self,
        onnx_path: Path,
        mode: str,
        model_name: str,
    ):
        """Create the inference session and the tokenizer.

        Args:
            onnx_path: Location of the ONNX model file; it is converted to
                ``str`` before being given to ``ort.InferenceSession``.
            mode: Forwarded to ``build_tokenizer`` here and to ``predict``
                in ``run``.
            model_name: Kept on the instance as ``self.model_name``; this
                class does not read it anywhere else.
        """
        self.mode = mode
        self.model_name = model_name

        # The session is created before the tokenizer, as in the original.
        model_file = str(onnx_path)
        self.session = ort.InferenceSession(model_file)
        self.tokenizer = build_tokenizer(self.mode)

    def run(
        self,
        samples: list[dict],
        max_len: int = 256,
        batch_size: int = 32,
        deduplicate: bool = False,
    ) -> list[dict]:
        """Run ``predict`` over ``samples`` with this pipeline's session.

        Args:
            samples: Input records, passed through to ``predict`` unchanged.
            max_len: Forwarded as ``max_len`` (presumably a token-length
                limit; confirm against ``predict``).
            batch_size: Forwarded as ``batch_size``.
            deduplicate: Forwarded as ``deduplicate``.

        Returns:
            Whatever ``predict`` returns (annotated as a list of dicts).
        """
        call_kwargs = {
            "samples": samples,
            "session": self.session,
            "tokenizer": self.tokenizer,
            "mode": self.mode,
            "max_len": max_len,
            "batch_size": batch_size,
            "deduplicate": deduplicate,
        }
        return predict(**call_kwargs)