from __future__ import annotations import json from dataclasses import dataclass from pathlib import Path import joblib import numpy as np from sentence_transformers import SentenceTransformer DEFAULT_ARTIFACT_DIR = Path("artifacts/sentence-function-classifier") @dataclass(frozen=True) class Prediction: label: str confidence: float probabilities: dict[str, float] class SentenceFunctionClassifier: def __init__(self, artifact_dir: str | Path = DEFAULT_ARTIFACT_DIR) -> None: self.artifact_dir = Path(artifact_dir) self.classifier = joblib.load(self.artifact_dir / "classifier.joblib") self.label_encoder = joblib.load(self.artifact_dir / "label_encoder.joblib") metadata_path = self.artifact_dir / "metadata.json" self.metadata = json.loads(metadata_path.read_text(encoding="utf-8")) self.embedding_model_name = self.metadata["embedding_model"] self.embedding_model = SentenceTransformer(self.embedding_model_name) def predict(self, sentence: str) -> Prediction: sentence = sentence.strip() if not sentence: raise ValueError("Please enter a sentence.") embedding = self.embedding_model.encode([sentence], normalize_embeddings=True) probabilities = self.classifier.predict_proba(embedding)[0] best_index = int(np.argmax(probabilities)) labels = list(self.label_encoder.classes_) scores = {label: float(probabilities[index]) for index, label in enumerate(labels)} return Prediction( label=labels[best_index], confidence=float(probabilities[best_index]), probabilities=dict(sorted(scores.items(), key=lambda item: item[1], reverse=True)), )