File size: 6,641 Bytes
2d0ef3b
50231a8
 
59b46a2
 
50231a8
 
 
 
2d0ef3b
 
50231a8
 
2d0ef3b
 
50231a8
59b46a2
 
efddb2f
59b46a2
 
efddb2f
2d0ef3b
 
 
 
 
 
 
 
59b46a2
 
 
2d0ef3b
 
 
 
 
 
 
 
 
 
 
 
 
 
59b46a2
 
2d0ef3b
efddb2f
59b46a2
 
 
 
 
 
50231a8
2d0ef3b
 
 
 
 
2571402
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d0ef3b
 
 
 
 
2571402
 
 
2d0ef3b
 
 
 
2571402
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d0ef3b
50231a8
2d0ef3b
 
50231a8
 
59b46a2
2d0ef3b
59b46a2
50231a8
2d0ef3b
59b46a2
2d0ef3b
59b46a2
2d0ef3b
59b46a2
 
 
 
 
 
2d0ef3b
 
 
 
 
 
 
 
50231a8
efddb2f
50231a8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import logging
from typing import Any

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from app.core.config import settings
from app.core.exceptions import ClassificationError

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class ClassifierService:
    _HYPOTHESIS_TEMPLATE = "This text is about {}."

    def __init__(self) -> None:
        self._tokenizer: Any | None = None
        self._model: Any | None = None

    def _load_model(self) -> tuple[Any, Any]:
        if self._tokenizer is None or self._model is None:
            try:
                tokenizer = AutoTokenizer.from_pretrained(
                    settings.classifier_model,
                    token=settings.huggingface_token,
                )
                model = AutoModelForSequenceClassification.from_pretrained(
                    settings.classifier_model,
                    token=settings.huggingface_token,
                )
                model.eval()
                model.to("cpu")

                if settings.enable_model_quantization:
                    try:
                        # Dynamic INT8 quantization for CPU inference.
                        quantized_model = torch.ao.quantization.quantize_dynamic(
                            model,
                            {torch.nn.Linear},
                            dtype=torch.qint8,
                        )
                        model = quantized_model
                    except Exception:
                        logger.warning(
                            "Model quantization failed; using non-quantized model instead.",
                            exc_info=True,
                        )

                self._tokenizer = tokenizer
                self._model = model
            except Exception as exc:
                raise ClassificationError("Unable to initialize classifier model") from exc

        return self._tokenizer, self._model

    def warmup(self) -> None:
        self._load_model()

    @staticmethod
    def _normalize_labels(labels: list[str]) -> list[str]:
        cleaned = [label.strip() for label in labels if isinstance(label, str) and label.strip()]
        return list(dict.fromkeys(cleaned))

    @staticmethod
    def _parse_label_id(value: Any) -> int | None:
        try:
            return int(value)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _extract_task_specific_entailment_id(model: Any) -> int | None:
        task_specific_params = getattr(model.config, "task_specific_params", {}) or {}
        if not isinstance(task_specific_params, dict):
            return None

        zero_shot_params = task_specific_params.get("zero-shot-classification", {})
        if not isinstance(zero_shot_params, dict):
            return None

        return ClassifierService._parse_label_id(zero_shot_params.get("entailment_id"))

    @staticmethod
    def _has_generic_label_names(model: Any) -> bool:
        label2id = getattr(model.config, "label2id", {}) or {}
        id2label = getattr(model.config, "id2label", {}) or {}

        labels: list[str] = []
        labels.extend(label for label in label2id.keys() if isinstance(label, str))
        labels.extend(label for label in id2label.values() if isinstance(label, str))
        if not labels:
            return False

        return all(label.lower().startswith("label_") for label in labels)

    @staticmethod
    def _resolve_entailment_id(model: Any) -> int:
        label2id = getattr(model.config, "label2id", {}) or {}
        for label, label_id in label2id.items():
            if isinstance(label, str) and label.lower().startswith("entail"):
                parsed = ClassifierService._parse_label_id(label_id)
                if parsed is not None:
                    return parsed

        id2label = getattr(model.config, "id2label", {}) or {}
        for label_id, label in id2label.items():
            if isinstance(label, str) and label.lower().startswith("entail"):
                parsed = ClassifierService._parse_label_id(label_id)
                if parsed is not None:
                    return parsed

        task_specific_entailment_id = ClassifierService._extract_task_specific_entailment_id(model)
        if task_specific_entailment_id is not None:
            return task_specific_entailment_id

        if settings.classifier_entailment_label_id is not None:
            return settings.classifier_entailment_label_id

        num_labels = ClassifierService._parse_label_id(getattr(model.config, "num_labels", None))
        if num_labels == 3 and (
            ClassifierService._has_generic_label_names(model) or (not label2id and not id2label)
        ):
            logger.warning(
                "Falling back to entailment label id 2 because model config labels are generic or missing "
                "and no explicit entailment mapping was found. Set CLASSIFIER_ENTAILMENT_LABEL_ID "
                "to override this behavior."
            )
            return 2

        raise ClassificationError(
            "Classifier model is missing an entailment label mapping. "
            "Set CLASSIFIER_ENTAILMENT_LABEL_ID in the environment when the model config "
            "does not expose an entailment label."
        )

    def classify(self, text: str, labels: list[str]) -> str:
        candidate_labels = self._normalize_labels(labels)
        if not candidate_labels:
            raise ClassificationError("No labels configured")

        tokenizer, model = self._load_model()
        entailment_id = self._resolve_entailment_id(model)

        try:
            sequence_pairs = [[text, self._HYPOTHESIS_TEMPLATE.format(label)] for label in candidate_labels]
            inputs = tokenizer(
                sequence_pairs,
                padding=True,
                truncation="only_first",
                return_tensors="pt",
            )

            with torch.no_grad():
                logits = model(**inputs).logits

            if logits.ndim != 2:
                raise ClassificationError("Classifier returned unexpected logits shape")
            if entailment_id < 0 or entailment_id >= logits.shape[-1]:
                raise ClassificationError("Entailment label index is out of range for classifier output")

            entailment_logits = logits[:, entailment_id]
            best_index = int(torch.argmax(entailment_logits).item())
            return candidate_labels[best_index]
        except Exception as exc:
            raise ClassificationError("Classifier prediction failed") from exc

# Module-level instance created at import time. Construction only sets the
# tokenizer/model slots to None; the model itself is loaded on first use.
classifier_service = ClassifierService()