# Provenance: uploaded via Hugging Face "upload-large-folder" tool (commit bd33eac, verified).
from __future__ import annotations
from contextlib import nullcontext
import pytest
import torch
from sentence_transformers import SparseEncoder
# Callable spy that stands in for a model's forward method and records outputs.
class ForwardMethodWrapper:
    """Proxy around ``model.forward`` that records every returned output.

    Each call delegates to the original ``forward``. When ``is_inference`` is
    True the delegated call runs inside ``torch.inference_mode()``; otherwise
    it runs under a no-op context, leaving autograd untouched.
    """

    def __init__(self, model, is_inference: bool = True):
        self.model = model
        # Keep a handle to the real forward so the proxy can delegate to it.
        self.original_forward = model.forward
        self.is_inference = is_inference
        self.outputs = []

    def __call__(self, *args, **kwargs):
        # Choose the execution context: inference_mode when requested, no-op otherwise.
        context = torch.inference_mode() if self.is_inference else nullcontext()
        with context:
            result = self.original_forward(*args, **kwargs)
            self.outputs.append(result)
        return result

    def reset(self):
        """Forget all recorded outputs."""
        self.outputs = []
@pytest.mark.parametrize(
    ["is_inference", "expected_keys"],
    [
        (
            False,
            {
                "input_ids",
                "attention_mask",
                "token_type_ids",
                "token_embeddings",
                "sentence_embedding",
                "sentence_embedding_backbone",
                "sentence_embedding_encoded",
                "sentence_embedding_encoded_4k",
                "auxiliary_embedding",
                "decoded_embedding_k",
                "decoded_embedding_4k",
                "decoded_embedding_aux",
                "decoded_embedding_k_pre_bias",
            },
        ),
        (True, {"input_ids", "attention_mask", "token_type_ids", "token_embeddings", "sentence_embedding"}),
    ],
)
def test_csr_outputs(csr_bert_tiny_model: SparseEncoder, is_inference: bool, expected_keys: set) -> None:
    """Check which output keys the CSR model's forward pass yields per mode.

    In inference mode only the final embeddings are expected; outside it the
    forward output also carries the intermediate CSR tensors.
    """
    model = csr_bert_tiny_model

    # Intercept forward so its return value can be inspected after the call.
    forward_spy = ForwardMethodWrapper(model, is_inference=is_inference)
    model.forward = forward_spy

    # Tokenize one sentence, move the features to the model's device, and run
    # the (wrapped) forward pass directly.
    features = model.tokenize(["This is a test sentence."])
    features = {name: tensor.to(model.device) for name, tensor in features.items()}
    model(features)

    # The recorded key set doubles as the mode check: inference mode omits the
    # intermediate CSR keys.
    assert set(forward_spy.outputs[0].keys()) == expected_keys
    # model.forward is intentionally not restored: the fixture model is not reused.