# tests/test_classifier_service.py
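"""Unit tests for ClassifierService in app.services.classifier_service.

Lightweight fakes stand in for the Hugging Face tokenizer and model so the
tests can exercise entailment-label selection and the quantization fallback
without loading real weights.
"""
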
from types import SimpleNamespace

import torch

import app.services.classifier_service as classifier_module


class _FakeTokenizer:
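    """Tokenizer stub: returns fixed-shape input_ids/attention_mask for any batch."""
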
    def __call__(self, sequence_pairs, padding, truncation, return_tensors):
        batch_size = len(sequence_pairs)
        return {
            "input_ids": torch.ones((batch_size, 2), dtype=torch.long),
            "attention_mask": torch.ones((batch_size, 2), dtype=torch.long),
        }


class _FakeInferenceModel:
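    """Model stub that returns pre-baked logits regardless of the encoded input."""
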
    def __init__(self, logits: torch.Tensor, config: SimpleNamespace | None = None) -> None:
        self._logits = logits
        self.config = config or SimpleNamespace(
            label2id={"CONTRADICTION": 0, "ENTAILMENT": 1},
            id2label={0: "CONTRADICTION", 1: "ENTAILMENT"},
        )

    def __call__(self, **kwargs):
        return SimpleNamespace(logits=self._logits)


class _FakeLoadModel:
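    """Bare model stub for _load_model tests; eval() and to() chain by returning self."""
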
    def __init__(self) -> None:
        self.config = SimpleNamespace(
            label2id={"CONTRADICTION": 0, "ENTAILMENT": 1},
            id2label={0: "CONTRADICTION", 1: "ENTAILMENT"},
        )

    def eval(self):
        return self

    def to(self, device):
        return self


def test_classify_uses_runtime_candidate_labels(monkeypatch):
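    """classify() should return the candidate label with the highest entailment logit."""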
    service = classifier_module.ClassifierService()
    tokenizer = _FakeTokenizer()
    model = _FakeInferenceModel(
        logits=torch.tensor(
            [
                [3.2, 0.4],  # finance -> low entailment
                [0.3, 4.1],  # sport -> highest entailment
                [1.5, 1.9],  # politics -> second-best entailment
            ]
        )
    )
    monkeypatch.setattr(service, "_load_model", lambda: (tokenizer, model))

    predicted = service.classify(
        "This article discusses the latest football transfer strategy.",
        ["finance", "sport", "politics"],
    )

    assert predicted == "sport"


def test_classify_uses_task_specific_entailment_id_when_label_names_are_generic(monkeypatch):
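    """When label names are generic, the entailment index comes from task_specific_params."""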
    service = classifier_module.ClassifierService()
    tokenizer = _FakeTokenizer()
    model = _FakeInferenceModel(
        logits=torch.tensor(
            [
                [1.8, 0.3, 0.4],  # finance -> low entailment
                [0.4, 0.7, 3.7],  # sport -> highest entailment
            ]
        ),
        config=SimpleNamespace(
            label2id={"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2},
            id2label={0: "LABEL_0", 1: "LABEL_1", 2: "LABEL_2"},
            task_specific_params={"zero-shot-classification": {"entailment_id": 2}},
            num_labels=3,
        ),
    )
    monkeypatch.setattr(service, "_load_model", lambda: (tokenizer, model))
    monkeypatch.setattr(classifier_module.settings, "classifier_entailment_label_id", None)

    predicted = service.classify(
        "The story is mostly about football transfers.",
        ["finance", "sport"],
    )

    assert predicted == "sport"


def test_classify_uses_explicit_entailment_id_setting_when_mapping_is_missing(monkeypatch):
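    """With no entailment label in the config, the classifier_entailment_label_id setting wins."""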
    service = classifier_module.ClassifierService()
    tokenizer = _FakeTokenizer()
    model = _FakeInferenceModel(
        logits=torch.tensor(
            [
                [2.0, 0.3],  # finance -> low entailment
                [0.2, 3.4],  # sport -> highest entailment
            ]
        ),
        config=SimpleNamespace(
            label2id={"NEGATIVE": 0, "POSITIVE": 1},
            id2label={0: "NEGATIVE", 1: "POSITIVE"},
            num_labels=2,
        ),
    )
    monkeypatch.setattr(service, "_load_model", lambda: (tokenizer, model))
    monkeypatch.setattr(classifier_module.settings, "classifier_entailment_label_id", 1)

    predicted = service.classify(
        "The story is mostly about football transfers.",
        ["finance", "sport"],
    )

    assert predicted == "sport"


def test_classify_falls_back_to_mnli_entailment_index_for_generic_three_label_configs(monkeypatch):
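    """Generic three-label configs fall back to the MNLI convention (entailment at index 2)."""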
    service = classifier_module.ClassifierService()
    tokenizer = _FakeTokenizer()
    model = _FakeInferenceModel(
        logits=torch.tensor(
            [
                [2.3, 0.6, 0.8],  # finance -> low entailment
                [0.4, 0.8, 3.9],  # sport -> highest entailment
            ]
        ),
        config=SimpleNamespace(
            label2id={"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2},
            id2label={0: "LABEL_0", 1: "LABEL_1", 2: "LABEL_2"},
            num_labels=3,
        ),
    )
    monkeypatch.setattr(service, "_load_model", lambda: (tokenizer, model))
    monkeypatch.setattr(classifier_module.settings, "classifier_entailment_label_id", None)

    predicted = service.classify(
        "The story is mostly about football transfers.",
        ["finance", "sport"],
    )

    assert predicted == "sport"


def test_classify_falls_back_to_mnli_entailment_index_for_missing_label_mapping(monkeypatch):
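    """Empty label2id/id2label mappings also fall back to the MNLI entailment index."""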
    service = classifier_module.ClassifierService()
    tokenizer = _FakeTokenizer()
    model = _FakeInferenceModel(
        logits=torch.tensor(
            [
                [1.8, 0.4, 0.5],  # finance -> low entailment
                [0.2, 0.5, 3.6],  # sport -> highest entailment
            ]
        ),
        config=SimpleNamespace(
            label2id={},
            id2label={},
            num_labels=3,
        ),
    )
    monkeypatch.setattr(service, "_load_model", lambda: (tokenizer, model))
    monkeypatch.setattr(classifier_module.settings, "classifier_entailment_label_id", None)

    predicted = service.classify(
        "The story is mostly about football transfers.",
        ["finance", "sport"],
    )

    assert predicted == "sport"


def test_model_quantization_falls_back_to_non_quantized_model(monkeypatch):
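    """If quantize_dynamic raises, _load_model should return the unquantized model."""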
    service = classifier_module.ClassifierService()
    fake_model = _FakeLoadModel()
    fake_tokenizer = object()
    monkeypatch.setattr(
        classifier_module.AutoTokenizer,
        "from_pretrained",
        lambda *args, **kwargs: fake_tokenizer,
    )
    monkeypatch.setattr(
        classifier_module.AutoModelForSequenceClassification,
        "from_pretrained",
        lambda *args, **kwargs: fake_model,
    )
    monkeypatch.setattr(classifier_module.settings, "enable_model_quantization", True)

    def _raise_quantization_error(*args, **kwargs):
        raise RuntimeError("quantization backend unavailable")

    monkeypatch.setattr(
        classifier_module.torch.ao.quantization,
        "quantize_dynamic",
        _raise_quantization_error,
    )

    _, loaded_model = service._load_model()

    assert loaded_model is fake_model