MarcusBennevall's picture
Upload folder using huggingface_hub
6af0821 verified
from __future__ import annotations

import json
from dataclasses import dataclass, field
from pathlib import Path

import joblib
import numpy as np
from sentence_transformers import SentenceTransformer
# Default directory holding the saved classifier artifacts. The path is
# relative, so file reads resolve it against the process's current working
# directory.
DEFAULT_ARTIFACT_DIR = Path("artifacts") / "sentence-function-classifier"
@dataclass(frozen=True)
class Prediction:
    """Result of classifying one sentence.

    Attributes cannot be reassigned (``frozen=True``), and instances are
    hashable. ``probabilities`` is a ``dict`` and therefore unhashable, so it
    is excluded from ``__hash__``. Equality still compares all three fields,
    so equal instances always hash equally. The dict itself can still be
    mutated in place; treat it as read-only.
    """

    # Label chosen by the classifier (the argmax label when produced by
    # SentenceFunctionClassifier.predict).
    label: str
    # Probability assigned to ``label``.
    confidence: float
    # Probability for every known label (predict() orders it most likely first).
    probabilities: dict[str, float] = field(hash=False)
class SentenceFunctionClassifier:
    """Predict a label for a single sentence from artifacts saved on disk.

    Pipeline: embed the sentence with a SentenceTransformer, score the
    embedding with a pre-fitted classifier's ``predict_proba``, then map each
    probability column back to a human-readable label.

    Files read from ``artifact_dir``:

    * ``classifier.joblib``: an object with ``predict_proba`` (presumably a
      scikit-learn classifier; confirm against the training script).
    * ``label_encoder.joblib``: an object with ``classes_`` and
      ``inverse_transform`` (presumably ``sklearn.preprocessing.LabelEncoder``).
    * ``metadata.json``: a JSON object whose ``"embedding_model"`` entry names
      the SentenceTransformer model to load.
    """

    def __init__(self, artifact_dir: str | Path = DEFAULT_ARTIFACT_DIR) -> None:
        """Load the classifier, label encoder, metadata, and embedding model.

        Args:
            artifact_dir: Directory containing the artifact files listed in
                the class docstring.

        Raises:
            FileNotFoundError: If an artifact file is missing.
            KeyError: If ``metadata.json`` has no ``"embedding_model"`` entry.
        """
        self.artifact_dir = Path(artifact_dir)
        # joblib.load unpickles arbitrary objects and can execute code: only
        # point this at trusted artifact directories, never at uploaded files.
        self.classifier = joblib.load(self.artifact_dir / "classifier.joblib")
        self.label_encoder = joblib.load(self.artifact_dir / "label_encoder.joblib")
        # Resolved once, before the (potentially slow) embedding model load,
        # so an inconsistent classifier/encoder pair is reported early.
        self._column_labels = self._resolve_column_labels()
        metadata_path = self.artifact_dir / "metadata.json"
        self.metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
        self.embedding_model_name = self.metadata["embedding_model"]
        self.embedding_model = SentenceTransformer(self.embedding_model_name)

    def _resolve_column_labels(self) -> list[str]:
        """Return the label for each ``predict_proba`` column, in column order."""
        # predict_proba orders its columns by ``classifier.classes_`` (the
        # targets the classifier saw during fit). Indexing
        # ``label_encoder.classes_`` by column position is only correct when
        # those targets are exactly 0..n-1, i.e. every class was present in
        # the training data.
        fitted = getattr(self.classifier, "classes_", None)
        if fitted is None:
            # Nothing to consult: keep the "column i == encoder class i" assumption.
            return list(self.label_encoder.classes_)
        fitted = np.asarray(fitted)
        if np.issubdtype(fitted.dtype, np.integer):
            # Classifier was fit on encoded integer targets; decode them.
            return list(self.label_encoder.inverse_transform(fitted))
        # Classifier was fit on raw labels, so its classes_ already are the labels.
        return list(fitted)

    def predict(self, sentence: str) -> Prediction:
        """Classify one sentence.

        Args:
            sentence: Input text. Leading and trailing whitespace is stripped.

        Returns:
            A Prediction holding the highest-probability label, its
            probability, and a mapping of every label to its probability,
            ordered from most to least likely.

        Raises:
            ValueError: If ``sentence`` is empty or whitespace-only.
            RuntimeError: If the classifier's output width does not match the
                number of known labels (inconsistent artifacts).
        """
        sentence = sentence.strip()
        if not sentence:
            raise ValueError("Please enter a sentence.")

        # NOTE(review): normalize_embeddings=True presumably mirrors how the
        # training embeddings were produced; confirm against the training script.
        embedding = self.embedding_model.encode([sentence], normalize_embeddings=True)
        # Batch of one: take the single row of class probabilities.
        probabilities = np.asarray(self.classifier.predict_proba(embedding)[0], dtype=float)

        labels = self._column_labels
        if len(labels) != len(probabilities):
            raise RuntimeError(
                f"Classifier returned {len(probabilities)} probabilities but "
                f"{len(labels)} labels are known; the artifacts are inconsistent."
            )

        best_index = int(np.argmax(probabilities))
        scores = {label: float(score) for label, score in zip(labels, probabilities)}
        return Prediction(
            label=labels[best_index],
            confidence=float(probabilities[best_index]),
            # Most likely first; sorted() is stable (also with reverse=True),
            # so tied labels keep their column order.
            probabilities=dict(sorted(scores.items(), key=lambda item: item[1], reverse=True)),
        )