Spaces:
Sleeping
Sleeping
| from types import SimpleNamespace | |
| import torch | |
| import app.services.classifier_service as classifier_module | |
| class _FakeTokenizer: | |
| def __call__(self, sequence_pairs, padding, truncation, return_tensors): | |
| batch_size = len(sequence_pairs) | |
| return { | |
| "input_ids": torch.ones((batch_size, 2), dtype=torch.long), | |
| "attention_mask": torch.ones((batch_size, 2), dtype=torch.long), | |
| } | |
| class _FakeInferenceModel: | |
| def __init__(self, logits: torch.Tensor, config: SimpleNamespace | None = None) -> None: | |
| self._logits = logits | |
| self.config = config or SimpleNamespace( | |
| label2id={"CONTRADICTION": 0, "ENTAILMENT": 1}, | |
| id2label={0: "CONTRADICTION", 1: "ENTAILMENT"}, | |
| ) | |
| def __call__(self, **kwargs): | |
| return SimpleNamespace(logits=self._logits) | |
| class _FakeLoadModel: | |
| def __init__(self) -> None: | |
| self.config = SimpleNamespace( | |
| label2id={"CONTRADICTION": 0, "ENTAILMENT": 1}, | |
| id2label={0: "CONTRADICTION", 1: "ENTAILMENT"}, | |
| ) | |
| def eval(self): | |
| return self | |
| def to(self, device): | |
| return self | |
| def test_classify_uses_runtime_candidate_labels(monkeypatch): | |
| service = classifier_module.ClassifierService() | |
| tokenizer = _FakeTokenizer() | |
| model = _FakeInferenceModel( | |
| logits=torch.tensor( | |
| [ | |
| [3.2, 0.4], # finance -> low entailment | |
| [0.3, 4.1], # sport -> highest entailment | |
| [1.5, 1.9], # politics -> second-best entailment | |
| ] | |
| ) | |
| ) | |
| monkeypatch.setattr(service, "_load_model", lambda: (tokenizer, model)) | |
| predicted = service.classify( | |
| "This article discusses the latest football transfer strategy.", | |
| ["finance", "sport", "politics"], | |
| ) | |
| assert predicted == "sport" | |
| def test_classify_uses_task_specific_entailment_id_when_label_names_are_generic(monkeypatch): | |
| service = classifier_module.ClassifierService() | |
| tokenizer = _FakeTokenizer() | |
| model = _FakeInferenceModel( | |
| logits=torch.tensor( | |
| [ | |
| [1.8, 0.3, 0.4], # finance -> low entailment | |
| [0.4, 0.7, 3.7], # sport -> highest entailment | |
| ] | |
| ), | |
| config=SimpleNamespace( | |
| label2id={"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2}, | |
| id2label={0: "LABEL_0", 1: "LABEL_1", 2: "LABEL_2"}, | |
| task_specific_params={"zero-shot-classification": {"entailment_id": 2}}, | |
| num_labels=3, | |
| ), | |
| ) | |
| monkeypatch.setattr(service, "_load_model", lambda: (tokenizer, model)) | |
| monkeypatch.setattr(classifier_module.settings, "classifier_entailment_label_id", None) | |
| predicted = service.classify( | |
| "The story is mostly about football transfers.", | |
| ["finance", "sport"], | |
| ) | |
| assert predicted == "sport" | |
| def test_classify_uses_explicit_entailment_id_setting_when_mapping_is_missing(monkeypatch): | |
| service = classifier_module.ClassifierService() | |
| tokenizer = _FakeTokenizer() | |
| model = _FakeInferenceModel( | |
| logits=torch.tensor( | |
| [ | |
| [2.0, 0.3], # finance -> low entailment | |
| [0.2, 3.4], # sport -> highest entailment | |
| ] | |
| ), | |
| config=SimpleNamespace( | |
| label2id={"NEGATIVE": 0, "POSITIVE": 1}, | |
| id2label={0: "NEGATIVE", 1: "POSITIVE"}, | |
| num_labels=2, | |
| ), | |
| ) | |
| monkeypatch.setattr(service, "_load_model", lambda: (tokenizer, model)) | |
| monkeypatch.setattr(classifier_module.settings, "classifier_entailment_label_id", 1) | |
| predicted = service.classify( | |
| "The story is mostly about football transfers.", | |
| ["finance", "sport"], | |
| ) | |
| assert predicted == "sport" | |
| def test_classify_falls_back_to_mnli_entailment_index_for_generic_three_label_configs(monkeypatch): | |
| service = classifier_module.ClassifierService() | |
| tokenizer = _FakeTokenizer() | |
| model = _FakeInferenceModel( | |
| logits=torch.tensor( | |
| [ | |
| [2.3, 0.6, 0.8], # finance -> low entailment | |
| [0.4, 0.8, 3.9], # sport -> highest entailment | |
| ] | |
| ), | |
| config=SimpleNamespace( | |
| label2id={"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2}, | |
| id2label={0: "LABEL_0", 1: "LABEL_1", 2: "LABEL_2"}, | |
| num_labels=3, | |
| ), | |
| ) | |
| monkeypatch.setattr(service, "_load_model", lambda: (tokenizer, model)) | |
| monkeypatch.setattr(classifier_module.settings, "classifier_entailment_label_id", None) | |
| predicted = service.classify( | |
| "The story is mostly about football transfers.", | |
| ["finance", "sport"], | |
| ) | |
| assert predicted == "sport" | |
| def test_classify_falls_back_to_mnli_entailment_index_for_missing_label_mapping(monkeypatch): | |
| service = classifier_module.ClassifierService() | |
| tokenizer = _FakeTokenizer() | |
| model = _FakeInferenceModel( | |
| logits=torch.tensor( | |
| [ | |
| [1.8, 0.4, 0.5], # finance -> low entailment | |
| [0.2, 0.5, 3.6], # sport -> highest entailment | |
| ] | |
| ), | |
| config=SimpleNamespace( | |
| label2id={}, | |
| id2label={}, | |
| num_labels=3, | |
| ), | |
| ) | |
| monkeypatch.setattr(service, "_load_model", lambda: (tokenizer, model)) | |
| monkeypatch.setattr(classifier_module.settings, "classifier_entailment_label_id", None) | |
| predicted = service.classify( | |
| "The story is mostly about football transfers.", | |
| ["finance", "sport"], | |
| ) | |
| assert predicted == "sport" | |
| def test_model_quantization_falls_back_to_non_quantized_model(monkeypatch): | |
| service = classifier_module.ClassifierService() | |
| fake_model = _FakeLoadModel() | |
| fake_tokenizer = object() | |
| monkeypatch.setattr( | |
| classifier_module.AutoTokenizer, | |
| "from_pretrained", | |
| lambda *args, **kwargs: fake_tokenizer, | |
| ) | |
| monkeypatch.setattr( | |
| classifier_module.AutoModelForSequenceClassification, | |
| "from_pretrained", | |
| lambda *args, **kwargs: fake_model, | |
| ) | |
| monkeypatch.setattr(classifier_module.settings, "enable_model_quantization", True) | |
| def _raise_quantization_error(*args, **kwargs): | |
| raise RuntimeError("quantization backend unavailable") | |
| monkeypatch.setattr( | |
| classifier_module.torch.ao.quantization, | |
| "quantize_dynamic", | |
| _raise_quantization_error, | |
| ) | |
| _, loaded_model = service._load_model() | |
| assert loaded_model is fake_model | |