File size: 1,748 Bytes
6af0821
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from __future__ import annotations

import json
from dataclasses import dataclass, field
from pathlib import Path

import joblib
import numpy as np
from sentence_transformers import SentenceTransformer


# Relative path: resolved against the process's current working directory
# when used, not against this module's location.
DEFAULT_ARTIFACT_DIR = Path("artifacts") / "sentence-function-classifier"


@dataclass(frozen=True)
class Prediction:
    """Result of classifying a single sentence.

    Fields cannot be reassigned after construction (``frozen=True``).
    Equality compares all three fields; hashing uses only ``label`` and
    ``confidence`` so that instances are hashable despite holding a dict.
    """

    # Highest-probability label chosen by the classifier.
    label: str
    # Probability assigned to ``label``.
    confidence: float
    # Probability for every label. Excluded from the hash because ``dict`` is
    # unhashable; without this, hash() on a frozen instance raises TypeError.
    probabilities: dict[str, float] = field(hash=False)


class SentenceFunctionClassifier:
    """Predict a label for a single sentence from saved model artifacts.

    The artifact directory must contain ``classifier.joblib``,
    ``label_encoder.joblib`` and ``metadata.json``; the metadata's
    ``embedding_model`` entry names the SentenceTransformer used to embed text.
    """

    def __init__(self, artifact_dir: str | Path = DEFAULT_ARTIFACT_DIR) -> None:
        """Load the classifier, label encoder, metadata and embedding model.

        Args:
            artifact_dir: Directory holding the saved artifacts.

        Raises:
            FileNotFoundError: If an expected artifact file is missing.
            KeyError: If ``metadata.json`` has no ``embedding_model`` entry.
        """
        self.artifact_dir = Path(artifact_dir)
        # SECURITY: joblib.load unpickles arbitrary objects and can execute
        # code. Only load artifact directories from a trusted source.
        self.classifier = joblib.load(self.artifact_dir / "classifier.joblib")
        self.label_encoder = joblib.load(self.artifact_dir / "label_encoder.joblib")

        metadata_path = self.artifact_dir / "metadata.json"
        self.metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
        self.embedding_model_name = self.metadata["embedding_model"]
        self.embedding_model = SentenceTransformer(self.embedding_model_name)

    def predict(self, sentence: str) -> Prediction:
        """Classify one sentence.

        Args:
            sentence: Input text; leading/trailing whitespace is stripped.

        Returns:
            A Prediction holding the most probable label, its probability,
            and every label's probability ordered from most to least likely.

        Raises:
            ValueError: If ``sentence`` is empty or whitespace-only.
            RuntimeError: If the classifier's probability vector and the
                label encoder disagree on the number of classes.
        """
        sentence = sentence.strip()
        if not sentence:
            raise ValueError("Please enter a sentence.")

        # Embeddings are L2-normalised here; presumably the classifier was
        # trained on normalised embeddings as well — TODO confirm.
        embedding = self.embedding_model.encode([sentence], normalize_embeddings=True)
        probabilities = self.classifier.predict_proba(embedding)[0]
        labels = list(self.label_encoder.classes_)

        # Column i of predict_proba is assumed to belong to encoded label i.
        # A count mismatch means the artifacts are out of sync. Fail loudly
        # rather than silently dropping or misattributing probabilities.
        # (RuntimeError, not ValueError, so callers that catch bad user input
        # do not mistake this for an empty sentence.)
        if len(probabilities) != len(labels):
            raise RuntimeError(
                f"Classifier returned {len(probabilities)} class probabilities "
                f"but the label encoder defines {len(labels)} labels."
            )

        best_index = int(np.argmax(probabilities))
        scores = {label: float(prob) for label, prob in zip(labels, probabilities)}

        return Prediction(
            label=labels[best_index],
            confidence=float(probabilities[best_index]),
            probabilities=dict(sorted(scores.items(), key=lambda item: item[1], reverse=True)),
        )