File size: 2,375 Bytes
eab2256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""Utilities to load persisted artifacts and score new text."""

from __future__ import annotations

from pathlib import Path
from typing import Dict, List, Tuple

import joblib

from .config import Config, load_config
from .logging_utils import get_logger

LOGGER = get_logger(__name__)


def load_artifacts(config: Config | None = None):
    """Load the persisted pipeline and any metrics saved alongside it.

    The artifact location is read from ``config.model``: the directory comes
    from ``artifact_dir`` (default ``"artifacts"``) and the file name from
    ``pipeline_filename`` (default ``"sentiment_pipeline.joblib"``).

    Args:
        config: Configuration to use. When ``None``, ``load_config()`` is
            called to obtain one.

    Returns:
        A ``(pipeline, metrics)`` tuple. If the loaded object is a dict with
        a ``"pipeline"`` key, that entry is the pipeline; otherwise the loaded
        object itself is returned as the pipeline. ``metrics`` is the dict's
        ``"metrics"`` entry, or ``{}`` when absent or when the artifact is
        not a dict.

    Raises:
        FileNotFoundError: If the artifact file does not exist.
    """
    cfg = load_config() if config is None else config

    directory = Path(cfg.model.get("artifact_dir", "artifacts"))
    filename = cfg.model.get("pipeline_filename", "sentiment_pipeline.joblib")
    artifact_path = directory / filename

    if not artifact_path.exists():
        raise FileNotFoundError(
            f"Model artifact not found at {artifact_path!s}. Train the model by running `python scripts/train.py`."
        )

    # NOTE: joblib.load unpickles the file; only point this at trusted artifacts.
    payload = joblib.load(artifact_path)

    metrics: Dict[str, float]
    if isinstance(payload, dict):
        # Bundled format: the pipeline and its metrics live in one dict.
        pipeline = payload["pipeline"] if "pipeline" in payload else payload
        metrics = payload.get("metrics", {})
    else:
        # Bare format: the file holds the pipeline object directly.
        pipeline = payload
        metrics = {}

    LOGGER.info("Loaded pipeline from %s", artifact_path)
    return pipeline, metrics


def predict(texts: List[str], config: Config | None = None) -> Tuple[List[str], List[Dict[str, float]]]:
    """Score a batch of texts with the persisted pipeline.

    Args:
        texts: Documents to classify.
        config: Optional configuration, forwarded unchanged to
            ``load_artifacts``.

    Returns:
        A ``(labels, probabilities)`` tuple. ``labels`` is the output of
        ``pipeline.predict`` converted with ``tolist()``. ``probabilities``
        holds one dict per input text, mapping each entry of
        ``pipeline.classes_`` to its ``predict_proba`` score as a ``float``.
        An empty ``texts`` yields ``([], [])``.

    Raises:
        FileNotFoundError: Propagated from ``load_artifacts`` when the model
            artifact is missing (not raised for an empty batch, which returns
            before the artifact is loaded).
    """
    # Nothing to score. Short-circuit instead of handing a zero-sample batch
    # to the estimator (scikit-learn style estimators reject empty input).
    if len(texts) == 0:
        return [], []

    pipeline, _ = load_artifacts(config)
    probabilities = pipeline.predict_proba(texts)
    labels = pipeline.classes_
    predicted_labels = pipeline.predict(texts)

    # Columns of predict_proba are paired positionally with classes_.
    probability_dicts: List[Dict[str, float]] = [
        {label: float(score) for label, score in zip(labels, row)}
        for row in probabilities
    ]

    return predicted_labels.tolist(), probability_dicts


def predict_with_threshold(text: str, config: Config | None = None) -> Tuple[str, Dict[str, float]]:
    """Classify a single text using configurable probability thresholds.

    Thresholds are read from ``config.training["probability_thresholds"]``
    (keys ``"positive"`` and ``"negative"``, each defaulting to ``0.5``).
    The decision is purely threshold-based; the label returned by
    ``predict`` is not used:

    1. ``"positive"`` if the positive probability meets its threshold;
    2. otherwise ``"negative"`` if the negative probability meets its
       threshold;
    3. otherwise ``"neutral"``.

    A class missing from the probability dict is treated as probability 0.0.

    Args:
        text: The document to classify.
        config: Configuration to use. When ``None``, ``load_config()`` is
            called to obtain one.

    Returns:
        A ``(label, probabilities)`` tuple where ``probabilities`` is the
        per-class probability dict produced by ``predict`` for ``text``.
    """
    # Use an explicit None check (as load_artifacts does) so that a config
    # object that happens to be falsy is not silently replaced.
    if config is None:
        config = load_config()

    _, probabilities = predict([text], config)
    probability = probabilities[0]

    # `or {}` also covers a config where the key is present but set to null.
    thresholds = config.training.get("probability_thresholds") or {}
    positive_threshold = thresholds.get("positive", 0.5)
    negative_threshold = thresholds.get("negative", 0.5)

    # Positive is checked first, so it wins if both thresholds are met.
    if probability.get("positive", 0.0) >= positive_threshold:
        selected_label = "positive"
    elif probability.get("negative", 0.0) >= negative_threshold:
        selected_label = "negative"
    else:
        selected_label = "neutral"

    return selected_label, probability


# Names exported by `from <module> import *`; the module's public API.
__all__ = ["load_artifacts", "predict", "predict_with_threshold"]