Spaces:
Sleeping
Sleeping
File size: 6,641 Bytes
2d0ef3b 50231a8 59b46a2 50231a8 2d0ef3b 50231a8 2d0ef3b 50231a8 59b46a2 efddb2f 59b46a2 efddb2f 2d0ef3b 59b46a2 2d0ef3b 59b46a2 2d0ef3b efddb2f 59b46a2 50231a8 2d0ef3b 2571402 2d0ef3b 2571402 2d0ef3b 2571402 2d0ef3b 50231a8 2d0ef3b 50231a8 59b46a2 2d0ef3b 59b46a2 50231a8 2d0ef3b 59b46a2 2d0ef3b 59b46a2 2d0ef3b 59b46a2 2d0ef3b 50231a8 efddb2f 50231a8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 | import logging
from typing import Any
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from app.core.config import settings
from app.core.exceptions import ClassificationError
logger = logging.getLogger(__name__)
class ClassifierService:
_HYPOTHESIS_TEMPLATE = "This text is about {}."
def __init__(self) -> None:
self._tokenizer: Any | None = None
self._model: Any | None = None
def _load_model(self) -> tuple[Any, Any]:
if self._tokenizer is None or self._model is None:
try:
tokenizer = AutoTokenizer.from_pretrained(
settings.classifier_model,
token=settings.huggingface_token,
)
model = AutoModelForSequenceClassification.from_pretrained(
settings.classifier_model,
token=settings.huggingface_token,
)
model.eval()
model.to("cpu")
if settings.enable_model_quantization:
try:
# Dynamic INT8 quantization for CPU inference.
quantized_model = torch.ao.quantization.quantize_dynamic(
model,
{torch.nn.Linear},
dtype=torch.qint8,
)
model = quantized_model
except Exception:
logger.warning(
"Model quantization failed; using non-quantized model instead.",
exc_info=True,
)
self._tokenizer = tokenizer
self._model = model
except Exception as exc:
raise ClassificationError("Unable to initialize classifier model") from exc
return self._tokenizer, self._model
def warmup(self) -> None:
self._load_model()
@staticmethod
def _normalize_labels(labels: list[str]) -> list[str]:
cleaned = [label.strip() for label in labels if isinstance(label, str) and label.strip()]
return list(dict.fromkeys(cleaned))
@staticmethod
def _parse_label_id(value: Any) -> int | None:
try:
return int(value)
except (TypeError, ValueError):
return None
@staticmethod
def _extract_task_specific_entailment_id(model: Any) -> int | None:
task_specific_params = getattr(model.config, "task_specific_params", {}) or {}
if not isinstance(task_specific_params, dict):
return None
zero_shot_params = task_specific_params.get("zero-shot-classification", {})
if not isinstance(zero_shot_params, dict):
return None
return ClassifierService._parse_label_id(zero_shot_params.get("entailment_id"))
@staticmethod
def _has_generic_label_names(model: Any) -> bool:
label2id = getattr(model.config, "label2id", {}) or {}
id2label = getattr(model.config, "id2label", {}) or {}
labels: list[str] = []
labels.extend(label for label in label2id.keys() if isinstance(label, str))
labels.extend(label for label in id2label.values() if isinstance(label, str))
if not labels:
return False
return all(label.lower().startswith("label_") for label in labels)
@staticmethod
def _resolve_entailment_id(model: Any) -> int:
label2id = getattr(model.config, "label2id", {}) or {}
for label, label_id in label2id.items():
if isinstance(label, str) and label.lower().startswith("entail"):
parsed = ClassifierService._parse_label_id(label_id)
if parsed is not None:
return parsed
id2label = getattr(model.config, "id2label", {}) or {}
for label_id, label in id2label.items():
if isinstance(label, str) and label.lower().startswith("entail"):
parsed = ClassifierService._parse_label_id(label_id)
if parsed is not None:
return parsed
task_specific_entailment_id = ClassifierService._extract_task_specific_entailment_id(model)
if task_specific_entailment_id is not None:
return task_specific_entailment_id
if settings.classifier_entailment_label_id is not None:
return settings.classifier_entailment_label_id
num_labels = ClassifierService._parse_label_id(getattr(model.config, "num_labels", None))
if num_labels == 3 and (
ClassifierService._has_generic_label_names(model) or (not label2id and not id2label)
):
logger.warning(
"Falling back to entailment label id 2 because model config labels are generic or missing "
"and no explicit entailment mapping was found. Set CLASSIFIER_ENTAILMENT_LABEL_ID "
"to override this behavior."
)
return 2
raise ClassificationError(
"Classifier model is missing an entailment label mapping. "
"Set CLASSIFIER_ENTAILMENT_LABEL_ID in the environment when the model config "
"does not expose an entailment label."
)
def classify(self, text: str, labels: list[str]) -> str:
candidate_labels = self._normalize_labels(labels)
if not candidate_labels:
raise ClassificationError("No labels configured")
tokenizer, model = self._load_model()
entailment_id = self._resolve_entailment_id(model)
try:
sequence_pairs = [[text, self._HYPOTHESIS_TEMPLATE.format(label)] for label in candidate_labels]
inputs = tokenizer(
sequence_pairs,
padding=True,
truncation="only_first",
return_tensors="pt",
)
with torch.no_grad():
logits = model(**inputs).logits
if logits.ndim != 2:
raise ClassificationError("Classifier returned unexpected logits shape")
if entailment_id < 0 or entailment_id >= logits.shape[-1]:
raise ClassificationError("Entailment label index is out of range for classifier output")
entailment_logits = logits[:, entailment_id]
best_index = int(torch.argmax(entailment_logits).item())
return candidate_labels[best_index]
except Exception as exc:
raise ClassificationError("Classifier prediction failed") from exc
# Module-level service instance. Constructing it does not load the tokenizer or
# model; loading is deferred until warmup() or classify() is first called.
classifier_service = ClassifierService()
|