# Extraction header: file size 11,500 bytes, revision bd33eac.
from __future__ import annotations
from pathlib import Path
import pytest
from sentence_transformers import SparseEncoderTrainer, SparseEncoderTrainingArguments
from sentence_transformers.model_card import generate_model_card
from sentence_transformers.sparse_encoder import losses
from sentence_transformers.util import is_datasets_available, is_training_available
# Guarded import: only bring in the `datasets` classes when the package is installed.
if is_datasets_available():
    from datasets import Dataset, DatasetDict

# The test in this module builds a SparseEncoderTrainer, so skip the whole
# module when the training extras are not installed.
if not is_training_available():
    pytest.skip(
        allow_module_level=True,
        reason='Sentence Transformers was not installed with the `["train"]` extra.',
    )
@pytest.fixture(scope="session")
def dummy_dataset():
    """
    Session-scoped fixture returning a small triplet dataset for the tests.

    Every column holds ten strings of the form "<column name> <i>" with i
    running from 1 to 10:
    {
        "anchor": ["anchor 1", "anchor 2", ..., "anchor 10"],
        "positive": ["positive 1", "positive 2", ..., "positive 10"],
        "negative": ["negative 1", "negative 2", ..., "negative 10"],
    }
    """
    column_names = ("anchor", "positive", "negative")
    row_numbers = range(1, 11)
    columns = {name: [f"{name} {idx}" for idx in row_numbers] for name in column_names}
    return Dataset.from_dict(columns)
# Fragments that recur across the expected model cards. Keeping them in one
# place lets each parametrized case spell out only what is unique to it.
_COMMON_TAGS = ("- sentence-transformers", "- sparse-encoder", "- sparse")
_SPLADE_INTRO = "This is a [SPLADE Sparse Encoder](https://www.sbert.net/docs/sparse_encoder/usage/usage.html) model finetuned from [sparse-encoder-testing/splade-bert-tiny-nq](https://huggingface.co/sparse-encoder-testing/splade-bert-tiny-nq)"
_LIBRARY_SUFFIX = " using the [sentence-transformers](https://www.SBERT.net) library."
_TOKEN_STATS = "<ul><li>min: 4 tokens</li><li>mean: 4.0 tokens</li><li>max: 4 tokens</li></ul>"
_UNNAMED_DATASET_SECTION = (
    "#### Unnamed Dataset",
    f"| details | {_TOKEN_STATS} | {_TOKEN_STATS} | {_TOKEN_STATS} |",
    " | <code>anchor 1</code> | <code>positive 1</code> | <code>negative 1</code> |",
)
_LOSS_SECTION = (
    "* Loss: [<code>SpladeLoss</code>](https://sbert.net/docs/package_reference/sparse_encoder/losses.html#spladeloss) with these parameters:",
    ' ```json\n {\n "loss": "SparseMultipleNegativesRankingLoss(scale=1.0, similarity_fct=\'dot_score\')",\n "document_regularizer_weight": 3e-05,\n "query_regularizer_weight": 5e-05\n }\n ```',
)


@pytest.mark.parametrize(
    ("model_fixture_name", "num_datasets", "expected_substrings"),
    [
        # num_datasets == 0 means "a single, unnamed dataset" rather than "no dataset"
        (
            "splade_bert_tiny_model",
            0,
            [
                *_COMMON_TAGS,
                "- splade",
                _SPLADE_INTRO,
                "**Maximum Sequence Length:** 512 tokens",
                "**Output Dimensionality:** 30522 dimensions",
                "**Similarity Function:** Dot Product",
                *_UNNAMED_DATASET_SECTION,
                *_LOSS_SECTION,
            ],
        ),
        (
            "splade_bert_tiny_model",
            1,
            [
                _SPLADE_INTRO + " on the train_0 dataset" + _LIBRARY_SUFFIX,
                "#### train_0",
                *_LOSS_SECTION,
            ],
        ),
        (
            "splade_bert_tiny_model",
            2,
            [
                _SPLADE_INTRO + " on the train_0 and train_1 datasets" + _LIBRARY_SUFFIX,
                "#### train_0",
                "#### train_1",
                *_LOSS_SECTION,
            ],
        ),
        (
            "splade_bert_tiny_model",
            10,
            [
                _SPLADE_INTRO
                + " on the train_0, train_1, train_2, train_3, train_4, train_5, train_6, train_7, train_8 and train_9 datasets"
                + _LIBRARY_SUFFIX,
                # With more than 3 datasets, each one is wrapped in <details><summary>
                "<details><summary>train_0</summary>",
                "#### train_0",
                "</details>\n<details><summary>train_9</summary>",
                "#### train_9",
                *_LOSS_SECTION,
            ],
        ),
        # Once the ", "-joined dataset names exceed 200 characters, the card
        # switches to a count such as "50 datasets"
        (
            "splade_bert_tiny_model",
            50,
            [
                _SPLADE_INTRO + " on 50 datasets" + _LIBRARY_SUFFIX,
                "<details><summary>train_0</summary>",
                "#### train_0",
                "</details>\n<details><summary>train_49</summary>",
                "#### train_49",
                *_LOSS_SECTION,
            ],
        ),
        (
            "csr_bert_tiny_model",
            0,
            [
                *_COMMON_TAGS,
                "- csr",
                "This is a [CSR Sparse Encoder](https://www.sbert.net/docs/sparse_encoder/usage/usage.html) model finetuned from [sentence-transformers-testing/stsb-bert-tiny-safetensors](https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors)"
                + _LIBRARY_SUFFIX,
                "**Maximum Sequence Length:** 512 tokens",
                "**Output Dimensionality:** 512 dimensions",
                "**Similarity Function:** Dot Product",
                *_UNNAMED_DATASET_SECTION,
                *_LOSS_SECTION,
            ],
        ),
        (
            "inference_free_splade_bert_tiny_model",
            0,
            [
                *_COMMON_TAGS,
                "- asymmetric",
                "- inference-free",
                "- splade",
                "This is a [Asymmetric Inference-free SPLADE Sparse Encoder](https://www.sbert.net/docs/sparse_encoder/usage/usage.html) model finetuned from [sparse-encoder-testing/inference-free-splade-bert-tiny-nq](https://huggingface.co/sparse-encoder-testing/inference-free-splade-bert-tiny-nq)"
                + _LIBRARY_SUFFIX,
                "**Maximum Sequence Length:** 512 tokens",
                "**Output Dimensionality:** 30522 dimensions",
                "**Similarity Function:** Dot Product",
                *_UNNAMED_DATASET_SECTION,
                *_LOSS_SECTION,
            ],
        ),
    ],
)
def test_model_card_base(
    model_fixture_name: str,
    dummy_dataset: Dataset,
    num_datasets: int,
    expected_substrings: list[str],
    request: pytest.FixtureRequest,
    tmp_path: Path,
) -> None:
    """
    Check that the generated model card contains the expected fragments.

    The model is looked up by fixture name, a trainer is constructed around it
    (which adds data to ``model.model_card_data``), and the rendered card is then
    searched for every entry of ``expected_substrings``.

    Args:
        model_fixture_name: Name of the pytest fixture that provides the model.
        dummy_dataset: The shared triplet dataset used as training data.
        num_datasets: 0 to train on ``dummy_dataset`` directly; otherwise the
            number of ``train_<i>`` entries in the ``DatasetDict`` passed to the trainer.
        expected_substrings: Fragments that must all occur in the model card.
        request: Used to resolve ``model_fixture_name`` into a fixture value.
        tmp_path: Temporary directory used as the trainer's output directory.
    """
    sparse_model = request.getfixturevalue(model_fixture_name)

    # A falsy count keeps the bare dataset; otherwise the same dataset is reused
    # under the names train_0 .. train_{num_datasets - 1}.
    training_data = (
        DatasetDict({f"train_{idx}": dummy_dataset for idx in range(num_datasets)})
        if num_datasets
        else dummy_dataset
    )

    ranking_loss = losses.SparseMultipleNegativesRankingLoss(model=sparse_model)
    splade_loss = losses.SpladeLoss(
        model=sparse_model,
        loss=ranking_loss,
        query_regularizer_weight=5e-5,  # regularization weight on the query side
        document_regularizer_weight=3e-5,  # regularization weight on the document side
    )

    # Only the inference-free model fixture is given a router mapping.
    router_mapping = None
    if "inference_free" in model_fixture_name:
        router_mapping = {"test": "query"}
    training_args = SparseEncoderTrainingArguments(
        output_dir=tmp_path,
        router_mapping=router_mapping,
    )

    # The trainer instance itself is not needed: constructing it is what adds
    # data to model.model_card_data.
    SparseEncoderTrainer(
        sparse_model,
        args=training_args,
        train_dataset=training_data,
        loss=splade_loss,
    )

    # Tip: when an expectation fails, dumping `card_text` to a markdown file
    # makes it much easier to inspect.
    card_text = generate_model_card(sparse_model)

    for expected in expected_substrings:
        assert expected in card_text

    # The card must never contain two consecutive empty lines
    assert "\n\n\n" not in card_text
|