|
|
from __future__ import annotations |
|
|
|
|
|
from contextlib import nullcontext |
|
|
|
|
|
import pytest |
|
|
import torch |
|
|
|
|
|
from sentence_transformers import SparseEncoder |
|
|
|
|
|
|
|
|
|
|
|
class ForwardMethodWrapper:
    """Callable stand-in for ``model.forward`` that records every result.

    The wrapper keeps a reference to the model's original ``forward`` and
    delegates each call to it, optionally under ``torch.inference_mode()``.
    Every returned value is stored, in call order, in ``outputs``.
    """

    def __init__(self, model, is_inference: bool = True):
        """Capture ``model.forward`` and start with an empty output log.

        Args:
            model: Object exposing a ``forward`` attribute to delegate to.
            is_inference: When true, delegated calls run inside
                ``torch.inference_mode()``; otherwise no context is applied.
        """
        self.model = model
        self.is_inference = is_inference
        # Grab the forward callable now, before the caller replaces it.
        self.original_forward = model.forward
        self.outputs = []

    def __call__(self, *args, **kwargs):
        """Delegate to the captured forward, record the result and return it."""
        if self.is_inference:
            guard = torch.inference_mode()
        else:
            guard = nullcontext()

        with guard:
            result = self.original_forward(*args, **kwargs)

        self.outputs.append(result)
        return result

    def reset(self):
        """Drop all recorded outputs by rebinding ``outputs`` to a new list."""
        self.outputs = []
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    ["is_inference", "expected_keys"],
    [
        (
            False,
            {
                "input_ids",
                "attention_mask",
                "token_type_ids",
                "token_embeddings",
                "sentence_embedding",
                "sentence_embedding_backbone",
                "sentence_embedding_encoded",
                "sentence_embedding_encoded_4k",
                "auxiliary_embedding",
                "decoded_embedding_k",
                "decoded_embedding_4k",
                "decoded_embedding_aux",
                "decoded_embedding_k_pre_bias",
            },
        ),
        (True, {"input_ids", "attention_mask", "token_type_ids", "token_embeddings", "sentence_embedding"}),
    ],
)
def test_csr_outputs(csr_bert_tiny_model: SparseEncoder, is_inference: bool, expected_keys: set) -> None:
    """Check which keys the model's forward pass returns.

    The model's ``forward`` is temporarily replaced by a
    ``ForwardMethodWrapper`` so the returned mapping can be inspected. The
    forward pass runs under ``torch.inference_mode()`` when ``is_inference``
    is true and without it otherwise, and the key set of the first recorded
    output must equal ``expected_keys``.
    """
    model = csr_bert_tiny_model

    # Intercept forward so its raw output can be captured.
    wrapper = ForwardMethodWrapper(model, is_inference=is_inference)
    model.forward = wrapper
    try:
        inputs = model.tokenize(["This is a test sentence."])
        inputs = {key: value.to(model.device) for key, value in inputs.items()}
        model(inputs)
    finally:
        # Always undo the patch so the fixture object is not left wrapped,
        # even when tokenization or the forward pass raises.
        model.forward = wrapper.original_forward

    # Fail with a clear message rather than an IndexError if forward never ran.
    assert wrapper.outputs, "model.forward was not called"
    assert set(wrapper.outputs[0].keys()) == expected_keys
|
|
|
|
|
|