File size: 5,902 Bytes
70b2ea0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from __future__ import annotations

import json
from pathlib import Path

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


# Directory containing this module; the default paths below are resolved
# relative to it so they do not depend on the current working directory.
PROJECT_DIR = Path(__file__).resolve().parents[0]
# Default model directory used by ArticleClassifier when none is given.
DEFAULT_MODEL_DIR = PROJECT_DIR.joinpath("artifacts", "large_model", "best_model")
# Default JSON label-mapping file used by ArticleClassifier when none is given.
DEFAULT_LABELS_PATH = PROJECT_DIR.joinpath("data", "processed_large", "label_mapping.json")


class ClassifierError(RuntimeError):
    """Raised when the classifier cannot be initialized or a prediction stage fails."""


class ArticleClassifier:
    """Classify an article from its title and optional abstract.

    Wraps a tokenizer and a sequence-classification model loaded from a local
    directory, together with a JSON label mapping, and exposes helpers that
    return per-label probabilities.
    """

    def __init__(
        self,
        model_dir: Path = DEFAULT_MODEL_DIR,
        labels_path: Path = DEFAULT_LABELS_PATH,
        max_length: int = 256,
    ) -> None:
        """Load the label mapping, tokenizer and model.

        Args:
            model_dir: Directory handed to ``from_pretrained`` for both the
                tokenizer and the model.
            labels_path: JSON file holding a non-empty object that maps each
                label name to its class id. ``predict`` looks labels up by the
                integer position of each model output, so the ids are expected
                to be those integer positions.
            max_length: Maximum token length used when truncating inputs.

        Raises:
            ClassifierError: If either path does not exist, the label mapping
                cannot be read or is empty/not an object, or the tokenizer or
                model fails to load.
        """
        self.model_dir = Path(model_dir)
        self.labels_path = Path(labels_path)
        self.max_length = max_length

        # Device preference order: Apple MPS, then CUDA, then CPU.
        if torch.backends.mps.is_available():
            device_name = "mps"
        elif torch.cuda.is_available():
            device_name = "cuda"
        else:
            device_name = "cpu"
        self.device = torch.device(device_name)

        # Check both paths up front so a missing file yields a clear message
        # instead of a lower-level loading error.
        if not self.labels_path.exists():
            raise ClassifierError(
                f"Failed to initialize classifier at labels loading stage: labels file not found at {self.labels_path}"
            )
        if not self.model_dir.exists():
            raise ClassifierError(
                f"Failed to initialize classifier at model loading stage: model directory not found at {self.model_dir}"
            )

        try:
            with self.labels_path.open("r", encoding="utf-8") as fh:
                self.label2id = json.load(fh)
        except Exception as exc:
            raise ClassifierError(
                f"Failed to initialize classifier at labels loading stage: {exc}"
            ) from exc

        if not isinstance(self.label2id, dict) or not self.label2id:
            raise ClassifierError(
                "Failed to initialize classifier at labels loading stage: label mapping is empty or invalid"
            )
        # Inverse mapping used to turn model output positions back into labels.
        self.id2label = {idx: label for label, idx in self.label2id.items()}

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_dir)
            self.model.to(self.device)
            self.model.eval()
        except Exception as exc:
            raise ClassifierError(
                f"Failed to initialize classifier at model loading stage: {exc}"
            ) from exc

    @property
    def labels(self) -> list[str]:
        """Return the label names ordered by ascending class id."""
        return [self.id2label[idx] for idx in sorted(self.id2label)]

    @staticmethod
    def build_input_text(title: str, abstract: str) -> str:
        """Build the single text string that is fed to the tokenizer.

        Runs of whitespace in both fields are collapsed to single spaces and
        leading/trailing whitespace is dropped. The abstract part is omitted
        when the abstract is empty after cleaning.
        """
        # " ".join(s.split()) already removes leading/trailing whitespace.
        clean_title = " ".join(title.split())
        clean_abstract = " ".join(abstract.split())
        if clean_abstract:
            return f"title: {clean_title} abstract: {clean_abstract}"
        return f"title: {clean_title}"

    def predict(self, title: str, abstract: str = "") -> list[dict[str, float | str]]:
        """Return the probability of every label for one article.

        Args:
            title: Article title.
            abstract: Article abstract; may be empty if a title is given.

        Returns:
            One ``{"label": ..., "probability": ...}`` dict per model output,
            sorted by descending probability.

        Raises:
            ValueError: If ``title`` or ``abstract`` is not a string, or both
                are blank.
            ClassifierError: If tokenization, model inference or result
                formatting fails.
        """
        if not isinstance(title, str):
            raise ValueError("Input validation error in predict: title must be a string.")
        if not isinstance(abstract, str):
            raise ValueError("Input validation error in predict: abstract must be a string.")
        if not title.strip() and not abstract.strip():
            raise ValueError(
                "Input validation error in predict: please provide at least a title or an abstract."
            )

        text = self.build_input_text(title=title, abstract=abstract)
        try:
            encoded = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length,
            )
            # Move every input tensor to the same device as the model.
            encoded = {key: value.to(self.device) for key, value in encoded.items()}
        except Exception as exc:
            raise ClassifierError(f"Failed during tokenization stage: {exc}") from exc

        try:
            with torch.inference_mode():
                logits = self.model(**encoded).logits
                # A single text is encoded, so drop the leading batch dimension.
                probabilities = torch.softmax(logits, dim=-1).squeeze(0).detach().cpu()
        except Exception as exc:
            raise ClassifierError(f"Failed during model inference stage: {exc}") from exc

        try:
            # A class id missing from the label mapping raises KeyError here and
            # is reported as a formatting failure.
            results: list[dict[str, float | str]] = [
                {
                    "label": self.id2label[class_id],
                    "probability": float(probability),
                }
                for class_id, probability in enumerate(probabilities.tolist())
            ]
        except Exception as exc:
            raise ClassifierError(f"Failed during prediction formatting stage: {exc}") from exc

        results.sort(key=lambda item: item["probability"], reverse=True)
        return results

    @staticmethod
    def select_top_95(
        predictions: list[dict[str, float | str]],
    ) -> list[dict[str, float | str]]:
        """Return the leading predictions that together cover 95% probability."""
        return ArticleClassifier.select_top_k_by_probability_mass(
            predictions=predictions,
            threshold=0.95,
        )

    @staticmethod
    def select_top_k_by_probability_mass(
        predictions: list[dict[str, float | str]],
        threshold: float = 0.95,
    ) -> list[dict[str, float | str]]:
        """Return the shortest prefix of ``predictions`` reaching ``threshold``.

        Predictions are consumed in the order given, so callers should pass
        them sorted by descending probability (as ``predict`` returns them).

        Args:
            predictions: Dicts that each carry a ``"probability"`` entry.
            threshold: Cumulative probability mass to reach, in ``(0, 1]``.

        Returns:
            The leading items whose probabilities sum to at least
            ``threshold``; all items if the threshold is never reached, and an
            empty list for empty input.

        Raises:
            ValueError: If ``threshold`` is outside ``(0, 1]``.
        """
        if not 0 < threshold <= 1:
            raise ValueError("Probability mass threshold must be in the interval (0, 1].")

        cumulative_probability = 0.0
        top_predictions: list[dict[str, float | str]] = []

        for item in predictions:
            top_predictions.append(item)
            cumulative_probability += float(item["probability"])
            # Compare against the caller-supplied threshold, not a fixed 0.95.
            if cumulative_probability >= threshold:
                break

        return top_predictions

    def predict_top_95(self, title: str, abstract: str = "") -> list[dict[str, float | str]]:
        """Run ``predict`` and keep only the labels covering 95% probability."""
        predictions = self.predict(title=title, abstract=abstract)
        return self.select_top_95(predictions)