| from __future__ import annotations |
|
|
| from contextlib import nullcontext |
|
|
| import pytest |
| import torch |
|
|
| from sentence_transformers import SparseEncoder |
|
|
|
|
| |
class ForwardMethodWrapper:
    """Callable stand-in for ``model.forward`` that records every result it produces.

    The model's current ``forward`` is captured when the wrapper is built. Each call
    delegates to that captured callable and appends the returned value to ``outputs``.
    When ``is_inference`` is true, the call runs inside ``torch.inference_mode()``;
    otherwise it runs under a no-op context, so the ambient grad mode applies.

    Args:
        model: Object whose ``forward`` attribute is captured and delegated to.
        is_inference: Whether to wrap each call in ``torch.inference_mode()``.
    """

    def __init__(self, model, is_inference: bool = True):
        self.model = model
        self.is_inference = is_inference
        # Capture forward now, before the caller swaps ``model.forward`` for this
        # wrapper; otherwise the wrapper would end up calling itself.
        self.original_forward = model.forward
        self.outputs = []

    def __call__(self, *args, **kwargs):
        """Run the captured forward with the given arguments, record the result, and return it."""
        if self.is_inference:
            mode = torch.inference_mode()
        else:
            mode = nullcontext()
        with mode:
            result = self.original_forward(*args, **kwargs)
            self.outputs.append(result)
        return result

    def reset(self):
        """Forget all recorded outputs by binding a fresh, empty list."""
        self.outputs = []
|
|
|
|
@pytest.mark.parametrize(
    ["is_inference", "expected_keys"],
    [
        (
            False,
            {
                "input_ids",
                "attention_mask",
                "token_type_ids",
                "token_embeddings",
                "sentence_embedding",
                "sentence_embedding_backbone",
                "sentence_embedding_encoded",
                "sentence_embedding_encoded_4k",
                "auxiliary_embedding",
                "decoded_embedding_k",
                "decoded_embedding_4k",
                "decoded_embedding_aux",
                "decoded_embedding_k_pre_bias",
                "modality",
            },
        ),
        (
            True,
            {"input_ids", "attention_mask", "token_type_ids", "token_embeddings", "sentence_embedding", "modality"},
        ),
    ],
)
def test_csr_outputs(
    csr_bert_tiny_model: SparseEncoder,
    is_inference: bool,
    expected_keys: set,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check which keys the CSR model's forward pass emits, with and without inference mode.

    ``model.forward`` is temporarily replaced by a ``ForwardMethodWrapper`` so the raw
    output dict of the forward pass can be captured. With ``is_inference=True``, the
    forward runs under ``torch.inference_mode()`` and only the base keys are expected.
    Otherwise the additional CSR keys listed in the parametrization are expected as well.
    """
    model = csr_bert_tiny_model

    # Patch through monkeypatch so the original ``forward`` is restored after the test,
    # even if an assertion fails. If the fixture model is shared with other tests
    # (not visible here), a plain assignment would leave the wrapper installed.
    wrapper = ForwardMethodWrapper(model, is_inference=is_inference)
    monkeypatch.setattr(model, "forward", wrapper)

    inputs = model.preprocess(["This is a test sentence."])
    # Move tensor inputs to the model's device; leave non-tensor entries untouched.
    inputs = {
        key: value.to(model.device) if isinstance(value, torch.Tensor) else value for key, value in inputs.items()
    }
    model(inputs)

    # Fail with a clear message, rather than an IndexError, if the wrapper was never called.
    assert wrapper.outputs, "model(...) did not dispatch to the patched forward"
    assert set(wrapper.outputs[0].keys()) == expected_keys
| |
|
|