File size: 2,330 Bytes
bd33eac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from __future__ import annotations

from contextlib import nullcontext

import pytest
import torch

from sentence_transformers import SparseEncoder


# Wrapper that records the outputs of a model's forward method
class ForwardMethodWrapper:
    """Callable that delegates to a model's ``forward`` and records each output.

    ``model.forward`` is captured when the wrapper is constructed, so the
    wrapper keeps calling that captured callable even if ``model.forward`` is
    reassigned afterwards.

    Args:
        model: Object exposing a ``forward`` attribute to delegate to.
        is_inference: If True, every call runs under ``torch.inference_mode()``;
            if False, the call runs without any extra context.
    """

    def __init__(self, model, is_inference: bool = True):
        self.model = model
        # Capture the forward callable now, before anything replaces it on the model
        self.original_forward = model.forward
        self.is_inference = is_inference
        # Outputs of every call so far, in call order
        self.outputs = []

    def __call__(self, *args, **kwargs):
        # Use inference mode only when requested; otherwise a no-op context
        if self.is_inference:
            mode_context = torch.inference_mode()
        else:
            mode_context = nullcontext()

        with mode_context:
            result = self.original_forward(*args, **kwargs)

        # Recording happens after the context has been exited
        self.outputs.append(result)
        return result

    def reset(self):
        """Drop the recorded outputs by rebinding ``outputs`` to a new empty list."""
        self.outputs = []


@pytest.mark.parametrize(
    ["is_inference", "expected_keys"],
    [
        # Without inference mode: the full set of keys is expected
        pytest.param(
            False,
            {
                "input_ids",
                "attention_mask",
                "token_type_ids",
                "token_embeddings",
                "sentence_embedding",
                "sentence_embedding_backbone",
                "sentence_embedding_encoded",
                "sentence_embedding_encoded_4k",
                "auxiliary_embedding",
                "decoded_embedding_k",
                "decoded_embedding_4k",
                "decoded_embedding_aux",
                "decoded_embedding_k_pre_bias",
            },
        ),
        # Under inference mode: only the reduced set of keys is expected
        pytest.param(
            True,
            {
                "input_ids",
                "attention_mask",
                "token_type_ids",
                "token_embeddings",
                "sentence_embedding",
            },
        ),
    ],
)
def test_csr_outputs(csr_bert_tiny_model: SparseEncoder, is_inference: bool, expected_keys: set) -> None:
    """Check which keys the model's forward output has with and without ``torch.inference_mode()``."""
    model = csr_bert_tiny_model

    # Swap the model's forward for a recording wrapper around the original forward
    recorder = ForwardMethodWrapper(model, is_inference=is_inference)
    model.forward = recorder

    # Tokenize one sentence, move every tensor to the model's device, and call the
    # model directly (not `encode`); the assertion below relies on that call
    # reaching the patched forward.
    tokenized = model.tokenize(["This is a test sentence."])
    features = {name: tensor.to(model.device) for name, tensor in tokenized.items()}
    model(features)

    # The first recorded output must expose exactly the expected keys for this mode
    observed_keys = set(recorder.outputs[0].keys())
    assert observed_keys == expected_keys
    # NOTE: forward is intentionally left patched; this relies on the model not
    # being reused after this test (presumably guaranteed by the fixture - verify).