"""Run every registered model (ONNX and fastText) over the same samples."""

import logging
from pathlib import Path

import onnxruntime as ort

from src.models.inference import build_tokenizer, predict
from src.pipelines.fasttext_pipeline import FastTextPipeline

log = logging.getLogger(__name__)


class PredictAllPipeline:
    """Fan a list of samples out to every registered model.

    ONNX models and fastText models are kept in separate registries but share
    a single result namespace: ``run()`` returns one entry per registered
    name. Reusing a name (in either registry) replaces the earlier model or
    overwrites its ``run()`` result, so a warning is logged when it happens.
    """

    def __init__(self) -> None:
        # name -> {"session": ort.InferenceSession, "tokenizer": <from
        # build_tokenizer>, "mode": str}
        self.models: dict[str, dict] = {}
        # name -> FastTextPipeline
        self.fasttext_models: dict[str, FastTextPipeline] = {}

    def _warn_if_registered(self, name: str) -> None:
        """Log a warning if *name* is already used by a registered model."""
        if name in self.models or name in self.fasttext_models:
            log.warning(
                "Model name %r is already registered in PredictAllPipeline; "
                "the earlier model will be replaced or its run() results "
                "overwritten",
                name,
            )

    def add_model(
        self,
        name: str,
        onnx_path: Path,
        mode: str,
    ) -> None:
        """Register an ONNX model under *name*.

        Args:
            name: Key under which this model's predictions appear in ``run()``.
            onnx_path: Path to the ONNX model file; an
                ``ort.InferenceSession`` is created from it immediately.
            mode: Passed to ``build_tokenizer()`` now and to ``predict()`` at
                run time; its accepted values are defined in
                ``src.models.inference`` (not visible here).
        """
        log.info("Adding %s to PredictAllPipeline", name)
        self._warn_if_registered(name)
        self.models[name] = {
            "session": ort.InferenceSession(str(onnx_path)),
            "tokenizer": build_tokenizer(mode),
            "mode": mode,
        }

    def add_fasttext(self, name: str, model_path: Path) -> None:
        """Register a fastText model under *name*.

        Args:
            name: Key under which this model's predictions appear in ``run()``;
                also passed to ``FastTextPipeline``.
            model_path: Path handed to ``FastTextPipeline``.
        """
        log.info("Adding %s (fastText) to PredictAllPipeline", name)
        self._warn_if_registered(name)
        self.fasttext_models[name] = FastTextPipeline(model_path, name)

    def run(
        self,
        samples: list[dict],
        max_len: int = 256,
        batch_size: int = 32,
        deduplicate: bool = False,
    ) -> dict[str, list[dict]]:
        """Run every registered model over *samples*.

        Args:
            samples: Input records. Their expected keys are defined by
                ``predict()`` / ``FastTextPipeline.run()`` (not visible here).
            max_len: Forwarded to ``predict()`` for ONNX models only.
            batch_size: Forwarded to ``predict()`` for ONNX models only.
            deduplicate: Forwarded to both ONNX and fastText models.

        Returns:
            Mapping from model name to that model's list of predictions.
            ONNX models run first, then fastText models; if a name exists in
            both registries, the fastText result is the one kept.
        """
        results: dict[str, list[dict]] = {}

        for name, m in self.models.items():
            results[name] = predict(
                samples=samples,
                session=m["session"],
                tokenizer=m["tokenizer"],
                mode=m["mode"],
                max_len=max_len,
                batch_size=batch_size,
                deduplicate=deduplicate,
            )

        # fastText pipelines take no max_len / batch_size.
        for name, ft in self.fasttext_models.items():
            results[name] = ft.run(
                samples=samples,
                deduplicate=deduplicate,
            )

        return results