| from __future__ import annotations |
|
|
| import json |
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
| import joblib |
| import numpy as np |
| from sentence_transformers import SentenceTransformer |
|
|
|
|
# Directory searched by default for the saved model artifacts
# (relative path, so it resolves against the current working directory).
DEFAULT_ARTIFACT_DIR = Path("artifacts/sentence-function-classifier")
|
|
|
|
@dataclass(frozen=True)
class Prediction:
    """Immutable result of classifying a single sentence.

    Attributes:
        label: Name of the predicted class.
        confidence: Probability associated with ``label``.
        probabilities: Mapping from each class label to its probability.
    """

    label: str
    confidence: float
    probabilities: dict[str, float]
|
|
|
|
class SentenceFunctionClassifier:
    """Classify a sentence by embedding it and scoring it with a saved classifier.

    The artifact directory is expected to contain ``classifier.joblib``,
    ``label_encoder.joblib`` and ``metadata.json``; the metadata must name
    the embedding model under the ``"embedding_model"`` key.
    """

    def __init__(self, artifact_dir: str | Path = DEFAULT_ARTIFACT_DIR) -> None:
        """Load the classifier, label encoder, metadata and embedding model.

        Args:
            artifact_dir: Directory holding the saved artifacts.

        Raises:
            KeyError: If ``metadata.json`` has no ``"embedding_model"`` entry.
        """
        self.artifact_dir = Path(artifact_dir)

        # NOTE: joblib.load unpickles its input, which can execute arbitrary
        # code. Only point this at artifact directories you trust.
        self.classifier = joblib.load(self.artifact_dir / "classifier.joblib")
        self.label_encoder = joblib.load(self.artifact_dir / "label_encoder.joblib")

        metadata_path = self.artifact_dir / "metadata.json"
        self.metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
        self.embedding_model_name = self.metadata["embedding_model"]
        self.embedding_model = SentenceTransformer(self.embedding_model_name)

    def predict(self, sentence: str) -> Prediction:
        """Return the most likely label and the per-label probabilities.

        Args:
            sentence: Text to classify; leading and trailing whitespace is
                stripped before encoding.

        Returns:
            A ``Prediction`` whose ``probabilities`` mapping is ordered from
            the highest to the lowest score.

        Raises:
            ValueError: If ``sentence`` is empty or whitespace-only, or if the
                classifier returns a different number of probabilities than
                the label encoder has classes.
        """
        sentence = sentence.strip()
        if not sentence:
            raise ValueError("Please enter a sentence.")

        # Encode as a one-item batch and keep the single resulting row.
        embedding = self.embedding_model.encode([sentence], normalize_embeddings=True)
        probabilities = self.classifier.predict_proba(embedding)[0]
        labels = list(self.label_encoder.classes_)

        # Scores are matched to labels purely by position, so a classifier
        # and encoder that disagree on the class count would mislabel results.
        if len(labels) != len(probabilities):
            raise ValueError(
                f"Classifier returned {len(probabilities)} probabilities but the "
                f"label encoder defines {len(labels)} classes."
            )

        best_index = int(np.argmax(probabilities))
        scores = {label: float(probabilities[index]) for index, label in enumerate(labels)}

        return Prediction(
            label=labels[best_index],
            confidence=float(probabilities[best_index]),
            probabilities=dict(sorted(scores.items(), key=lambda item: item[1], reverse=True)),
        )
|
|