File size: 1,709 Bytes
2d31cef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import logging
from pathlib import Path

import onnxruntime as ort

from src.models.inference import build_tokenizer, predict
from src.pipelines.fasttext_pipeline import FastTextPipeline

log = logging.getLogger(__name__)


class PredictAllPipeline:
    """Run every registered model over the same samples.

    Two kinds of model can be registered:

    * ONNX models (``add_model``): each gets its own ``ort.InferenceSession``
      and a tokenizer from ``build_tokenizer(mode)``, and is run through
      ``predict``.
    * fastText models (``add_fasttext``): each is wrapped in a
      ``FastTextPipeline`` and run through its ``run`` method.

    ``run`` returns one prediction list per model, keyed by the name the
    model was registered under. Because all outputs share one dict, names
    must be unique across both kinds of model; a cross-kind collision is
    rejected at registration time instead of silently dropping results.
    """

    def __init__(self):
        # name -> {"session": ort.InferenceSession, "tokenizer": ..., "mode": str}
        self.models: dict[str, dict] = {}
        # name -> FastTextPipeline built from the model file
        self.fasttext_models: dict[str, FastTextPipeline] = {}

    def _ensure_name_free(self, name: str, other: dict) -> None:
        """Raise ``ValueError`` if *name* is already registered in *other*.

        Only the registry of the *other* model kind is checked, so
        re-registering a name within the same kind still replaces the
        previous model, as it always has.
        """
        if name in other:
            raise ValueError(
                f"Model name {name!r} is already registered as a different "
                "model type; names must be unique across ONNX and fastText "
                "models because run() keys its results by name"
            )

    def add_model(
        self,
        name: str,
        onnx_path: Path,
        mode: str,
    ) -> None:
        """Register an ONNX model under *name*.

        Args:
            name: Key under which this model's predictions appear in ``run``.
            onnx_path: Path to the ONNX file; it is passed to
                ``ort.InferenceSession`` as a string.
            mode: Forwarded to ``build_tokenizer`` and later to ``predict``
                (presumably selects the task / tokenization scheme — confirm
                against ``src.models.inference``).

        Raises:
            ValueError: If *name* is already registered as a fastText model.
        """
        # Check before loading the session so a bad name costs nothing.
        self._ensure_name_free(name, self.fasttext_models)
        log.info("Adding %s to PredictAllPipeline", name)
        self.models[name] = {
            "session": ort.InferenceSession(str(onnx_path)),
            "tokenizer": build_tokenizer(mode),
            "mode": mode,
        }

    def add_fasttext(self, name: str, model_path: Path) -> None:
        """Register a fastText model under *name*.

        Args:
            name: Key under which this model's predictions appear in ``run``;
                also passed to ``FastTextPipeline``.
            model_path: Path to the fastText model file.

        Raises:
            ValueError: If *name* is already registered as an ONNX model.
        """
        # Check before building the pipeline so a bad name costs nothing.
        self._ensure_name_free(name, self.models)
        log.info("Adding %s (fastText) to PredictAllPipeline", name)
        self.fasttext_models[name] = FastTextPipeline(model_path, name)

    def run(
        self,
        samples: list[dict],
        max_len: int = 256,
        batch_size: int = 32,
        deduplicate: bool = False,
    ) -> dict[str, list[dict]]:
        """Predict *samples* with every registered model.

        Args:
            samples: Input records, passed unchanged to every model.
            max_len: Forwarded to ``predict`` for ONNX models only.
            batch_size: Forwarded to ``predict`` for ONNX models only.
            deduplicate: Forwarded to both ONNX and fastText models.

        Returns:
            A dict mapping each model name to its list of predictions.
            ONNX models come first, then fastText models, each in
            registration order.
        """
        results: dict[str, list[dict]] = {}
        for name, m in self.models.items():
            results[name] = predict(
                samples=samples,
                session=m["session"],
                tokenizer=m["tokenizer"],
                mode=m["mode"],
                max_len=max_len,
                batch_size=batch_size,
                deduplicate=deduplicate,
            )
        for name, ft in self.fasttext_models.items():
            # fastText pipelines take no max_len / batch_size.
            results[name] = ft.run(
                samples=samples,
                deduplicate=deduplicate,
            )
        return results