# sv-task/src/pipelines/predict_all_pipeline.py
# Last change: commit 2d31cef ("pipeline classes") by lamossta
import logging
from pathlib import Path
import onnxruntime as ort
from src.models.inference import build_tokenizer, predict
from src.pipelines.fasttext_pipeline import FastTextPipeline
# Module-level logger named after this module's import path (``__name__``).
log = logging.getLogger(__name__)
class PredictAllPipeline:
    """Run every registered model over the same samples in a single call.

    Two kinds of models can be registered:

    * ONNX models via :meth:`add_model`, executed with ``predict`` using an
      ``onnxruntime`` session and a tokenizer built from ``mode``;
    * fastText models via :meth:`add_fasttext`, wrapped in ``FastTextPipeline``.

    :meth:`run` returns one entry per registered name in a single dict, so a
    name must not be shared between an ONNX model and a fastText model
    (otherwise one model's predictions would silently overwrite the other's).
    """

    def __init__(self):
        # name -> {"session": onnxruntime.InferenceSession, "tokenizer": ..., "mode": str}
        self.models: dict[str, dict] = {}
        # name -> FastTextPipeline wrapping the model file
        self.fasttext_models: dict[str, FastTextPipeline] = {}

    def add_model(
        self,
        name: str,
        onnx_path: Path,
        mode: str,
    ) -> None:
        """Register an ONNX model under ``name``.

        Args:
            name: Key under which this model's predictions appear in ``run``'s
                result. Must not already be used by a fastText model.
            onnx_path: Path to the ``.onnx`` file; loaded eagerly into an
                ``onnxruntime.InferenceSession``.
            mode: Passed to ``build_tokenizer`` and later to ``predict``.

        Raises:
            ValueError: If ``name`` is already registered as a fastText model.
        """
        # Check before building the session so a clash fails fast and cheaply.
        if name in self.fasttext_models:
            raise ValueError(
                f"Model name {name!r} is already registered as a fastText model; "
                "names must be unique across ONNX and fastText models"
            )
        if name in self.models:
            log.warning("Replacing existing ONNX model %s in PredictAllPipeline", name)
        log.info("Adding %s to PredictAllPipeline", name)
        self.models[name] = {
            "session": ort.InferenceSession(str(onnx_path)),
            "tokenizer": build_tokenizer(mode),
            "mode": mode,
        }

    def add_fasttext(self, name: str, model_path: Path) -> None:
        """Register a fastText model under ``name``.

        Args:
            name: Key under which this model's predictions appear in ``run``'s
                result. Must not already be used by an ONNX model.
            model_path: Path handed to ``FastTextPipeline`` together with ``name``.

        Raises:
            ValueError: If ``name`` is already registered as an ONNX model.
        """
        # Check before constructing the wrapper so a clash fails fast and cheaply.
        if name in self.models:
            raise ValueError(
                f"Model name {name!r} is already registered as an ONNX model; "
                "names must be unique across ONNX and fastText models"
            )
        if name in self.fasttext_models:
            log.warning("Replacing existing fastText model %s in PredictAllPipeline", name)
        log.info("Adding %s (fastText) to PredictAllPipeline", name)
        self.fasttext_models[name] = FastTextPipeline(model_path, name)

    def run(
        self,
        samples: list[dict],
        max_len: int = 256,
        batch_size: int = 32,
        deduplicate: bool = False,
    ) -> dict[str, list[dict]]:
        """Predict ``samples`` with every registered model.

        Models run sequentially: all ONNX models first, in registration order,
        then all fastText models.

        Args:
            samples: Input records, forwarded unchanged to each model. Their
                expected keys are defined by ``predict`` / ``FastTextPipeline.run``.
            max_len: Tokenizer truncation length; used by ONNX models only.
            batch_size: Inference batch size; used by ONNX models only.
            deduplicate: Forwarded to both ONNX and fastText prediction.

        Returns:
            Mapping from registered model name to that model's list of
            prediction dicts. An empty dict if no models are registered.
        """
        results: dict[str, list[dict]] = {}
        for name, entry in self.models.items():
            results[name] = predict(
                samples=samples,
                session=entry["session"],
                tokenizer=entry["tokenizer"],
                mode=entry["mode"],
                max_len=max_len,
                batch_size=batch_size,
                deduplicate=deduplicate,
            )
        for name, ft in self.fasttext_models.items():
            results[name] = ft.run(
                samples=samples,
                deduplicate=deduplicate,
            )
        return results