| import logging |
| from pathlib import Path |
|
|
| import onnxruntime as ort |
|
|
| from src.models.inference import build_tokenizer, predict |
| from src.pipelines.fasttext_pipeline import FastTextPipeline |
|
|
# Module-level logger named after this module; the pipeline methods below
# use it to report each model registration.
log = logging.getLogger(__name__)
|
|
|
|
class PredictAllPipeline:
    """Run one set of samples through every registered model.

    Two registries are maintained:

    * ``models`` -- ONNX-backed entries, each a dict holding the
      ``onnxruntime`` session, the tokenizer built for the entry's mode,
      and the mode string itself.
    * ``fasttext_models`` -- ``FastTextPipeline`` instances.

    ``run`` merges the output of both registries into a single dict keyed
    by model name.
    """

    def __init__(self):
        """Create a pipeline with both registries empty."""
        self.models: dict[str, dict] = {}
        self.fasttext_models: dict[str, FastTextPipeline] = {}

    def add_model(self, name: str, onnx_path: Path, mode: str) -> None:
        """Register an ONNX model under ``name``.

        Args:
            name: Key used for this model in ``self.models`` and in the
                dict returned by ``run``. Registering the same name again
                replaces the earlier entry.
            onnx_path: Location of the ONNX file; converted to ``str``
                before being passed to ``ort.InferenceSession``.
            mode: Passed to ``build_tokenizer`` and stored with the entry
                so ``run`` can forward it to ``predict``.
        """
        log.info(f"Adding {name} to PredictAllPipeline")
        # Session first, tokenizer second; the entry is only stored once
        # both have been created successfully.
        session = ort.InferenceSession(str(onnx_path))
        tokenizer = build_tokenizer(mode)
        self.models[name] = {
            "session": session,
            "tokenizer": tokenizer,
            "mode": mode,
        }

    def add_fasttext(self, name: str, model_path: Path) -> None:
        """Register a fastText model under ``name``.

        Args:
            name: Key used in ``self.fasttext_models`` and in the dict
                returned by ``run``; also handed to ``FastTextPipeline``
                as its second argument.
            model_path: Handed to ``FastTextPipeline`` as its first
                argument.
        """
        log.info(f"Adding {name} (fastText) to PredictAllPipeline")
        pipeline = FastTextPipeline(model_path, name)
        self.fasttext_models[name] = pipeline

    def run(
        self,
        samples: list[dict],
        max_len: int = 256,
        batch_size: int = 32,
        deduplicate: bool = False,
    ) -> dict[str, list[dict]]:
        """Run ``samples`` through every registered model.

        ONNX entries are evaluated first via ``predict``, then each
        fastText pipeline via its own ``run`` method. A name present in
        both registries therefore ends up holding the fastText output.

        Args:
            samples: Input records, forwarded unchanged to every model.
            max_len: Forwarded to ``predict`` (ONNX entries only).
            batch_size: Forwarded to ``predict`` (ONNX entries only).
            deduplicate: Forwarded to both ``predict`` and the fastText
                pipelines.

        Returns:
            A dict mapping each registered model name to whatever that
            model's prediction call returned; empty when nothing has been
            registered.
        """
        outputs: dict[str, list[dict]] = {
            model_name: predict(
                samples=samples,
                session=entry["session"],
                tokenizer=entry["tokenizer"],
                mode=entry["mode"],
                max_len=max_len,
                batch_size=batch_size,
                deduplicate=deduplicate,
            )
            for model_name, entry in self.models.items()
        }
        # fastText results are merged afterwards so they win on a name clash.
        outputs.update(
            (ft_name, ft_pipeline.run(samples=samples, deduplicate=deduplicate))
            for ft_name, ft_pipeline in self.fasttext_models.items()
        )
        return outputs
|
|