from __future__ import annotations
import tempfile
from contextlib import nullcontext
from pathlib import Path
import pytest
import torch
from datasets import DatasetDict
from tokenizers.processors import TemplateProcessing
from sentence_transformers import SparseEncoder, SparseEncoderTrainer, SparseEncoderTrainingArguments
from sentence_transformers.sparse_encoder import losses
from sentence_transformers.util import is_datasets_available, is_training_available
from tests.utils import SafeTemporaryDirectory

if is_datasets_available():
    from datasets import Dataset

if not is_training_available():
pytest.skip(
reason='Sentence Transformers was not installed with the `["train"]` extra.',
allow_module_level=True,
)


@pytest.fixture
def dummy_sparse_encoder_for_trainer() -> SparseEncoder:
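    """Tiny pretrained sparse encoder (SPLADE on bert-tiny) that is cheap enough for quick trainer tests."""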
return SparseEncoder("sparse-encoder-testing/splade-bert-tiny-nq")


@pytest.fixture
def dummy_train_eval_datasets_for_trainer() -> tuple[Dataset, Dataset]:
# Create minimal datasets for trainer tests
train_data = {
"sentence1": [f"train_s1_{i}" for i in range(20)],
"sentence2": [f"train_s2_{i}" for i in range(20)],
"score": [float(i % 2) for i in range(20)],
}
eval_data = {
"sentence1": [f"eval_s1_{i}" for i in range(10)],
"sentence2": [f"eval_s2_{i}" for i in range(10)],
"score": [float(i % 2) for i in range(10)],
}
train_dataset = Dataset.from_dict(train_data)
eval_dataset = Dataset.from_dict(eval_data)
return train_dataset, eval_dataset


def test_model_card_reuse(dummy_sparse_encoder_for_trainer: SparseEncoder):
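    """Saving after trainer setup should produce a fresh model card rather than reusing the original README text."""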
model = dummy_sparse_encoder_for_trainer
initial_card_text = model._model_card_text
SparseEncoderTrainer(
model=model,
loss=losses.SpladeLoss(
model=model,
loss=losses.SparseMultipleNegativesRankingLoss(model=model),
document_regularizer_weight=3e-5,
query_regularizer_weight=5e-5,
),
)
with SafeTemporaryDirectory() as tmp_folder:
model_path = Path(tmp_folder) / "sparse_model_local"
model.save_pretrained(str(model_path))
with open(model_path / "README.md", encoding="utf8") as f:
trained_model_card_text = f.read()
if initial_card_text:
assert trained_model_card_text != initial_card_text
else:
assert trained_model_card_text is not None # Should have created one
@pytest.mark.parametrize("streaming", [False, True])
def test_trainer(
dummy_sparse_encoder_for_trainer: SparseEncoder,
dummy_train_eval_datasets_for_trainer: tuple[Dataset, Dataset],
streaming: bool,
) -> None:
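    """Train for two steps (with map-style or streaming datasets) and check that the model parameters change."""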
model = dummy_sparse_encoder_for_trainer
train_dataset, eval_dataset = dummy_train_eval_datasets_for_trainer
context = nullcontext()
if streaming:
train_dataset = train_dataset.to_iterable_dataset()
eval_dataset = eval_dataset.to_iterable_dataset()
original_model_params = [p.clone() for p in model.parameters()]
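    # SpladeLoss wraps a ranking loss and adds sparsity regularization on the query and document representations.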
loss = losses.SpladeLoss(
model=model,
loss=losses.SparseMultipleNegativesRankingLoss(model=model),
document_regularizer_weight=3e-5,
query_regularizer_weight=5e-5,
)
with tempfile.TemporaryDirectory() as temp_dir:
args = SparseEncoderTrainingArguments(
output_dir=str(temp_dir),
max_steps=2,
eval_strategy="steps", # Changed from eval_steps to eval_strategy
eval_steps=2,
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
logging_steps=1,
remove_unused_columns=False, # Important for custom dict datasets
)
        with context:  # nullcontext; training should succeed for both map-style and streaming datasets
trainer = SparseEncoderTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss,
)
trainer.train()
if isinstance(context, nullcontext):
# Check if model parameters have changed after training
model_changed = False
for p_orig, p_new in zip(original_model_params, model.parameters()):
if not torch.equal(p_orig, p_new):
model_changed = True
break
assert model_changed, "Model parameters should have changed after training."
# Simple check to ensure prediction works after training
try:
model.encode(["Test sentence after training."])
except Exception as e:
pytest.fail(f"Encoding failed after training: {e}")
@pytest.mark.parametrize("has_bos_token", [True, False])
@pytest.mark.parametrize("has_eos_token", [True, False])
def test_data_collator(
csr_bert_tiny_model: SparseEncoder,
stsb_dataset_dict: DatasetDict,
has_bos_token: bool,
has_eos_token: bool,
tmp_path: Path,
) -> None:
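    """The collator should detect the tokenizer's BOS/EOS configuration and compute the prompt length accordingly."""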
    # Test that the data collator correctly recognizes whether the tokenizer adds BOS/CLS and EOS/SEP tokens
model = csr_bert_tiny_model
    # We need to set this to False, otherwise the pooling never needs the prompt length:
model.set_pooling_include_prompt(False)
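    # Arbitrary ids that stand in for the BOS/CLS and EOS/SEP tokens, easy to spot in the asserts below.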
dummy_bos_token_id = 400
dummy_eos_token_id = 500
model.tokenizer.cls_token_id = dummy_bos_token_id if has_bos_token else None
model.tokenizer.sep_token_id = dummy_eos_token_id if has_eos_token else None
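    # Rebuild the post-processor so tokenization matches the simulated special-token configuration.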
if has_bos_token:
if has_eos_token:
model.tokenizer._tokenizer.post_processor = TemplateProcessing(
single="[CLS] $0 [SEP]",
special_tokens=[
("[CLS]", dummy_bos_token_id),
("[SEP]", dummy_eos_token_id),
],
)
else:
model.tokenizer._tokenizer.post_processor = TemplateProcessing(
single="[CLS] $0",
special_tokens=[("[CLS]", dummy_bos_token_id)],
)
else:
if has_eos_token:
model.tokenizer._tokenizer.post_processor = TemplateProcessing(
single="$0 [SEP]",
special_tokens=[("[SEP]", dummy_eos_token_id)],
)
else:
model.tokenizer._tokenizer.post_processor = TemplateProcessing(
single="$0",
special_tokens=[],
)
# Check that we can update the tokenizer in this way
if has_eos_token:
assert model.tokenize(["dummy text"])["input_ids"].flatten()[-1] == dummy_eos_token_id
else:
assert model.tokenize(["dummy text"])["input_ids"].flatten()[-1] != dummy_eos_token_id
if has_bos_token:
assert model.tokenize(["dummy text"])["input_ids"].flatten()[0] == dummy_bos_token_id
else:
assert model.tokenize(["dummy text"])["input_ids"].flatten()[0] != dummy_bos_token_id
train_dataset = stsb_dataset_dict["train"].select(range(10))
eval_dataset = stsb_dataset_dict["validation"].select(range(10))
loss = losses.CSRLoss(
model=model,
loss=losses.SparseMultipleNegativesRankingLoss(model=model),
)
args = SparseEncoderTrainingArguments(
output_dir=tmp_path,
max_steps=2,
eval_steps=2,
eval_strategy="steps",
per_device_train_batch_size=1,
per_device_eval_batch_size=1,
prompts="Prompt: ", # Single prompt to all columns and all datasets
)
trainer = SparseEncoderTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss,
)
trainer.train()
# Check that the data collator correctly recognizes the prompt length
only_prompt_length = len(model.tokenizer(["Prompt: "], add_special_tokens=False)["input_ids"][0])
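    # The BOS/CLS token added by the post-processor also counts towards the prompt length.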
if has_bos_token:
only_prompt_length += 1
assert trainer.data_collator._prompt_length_mapping == {("Prompt: ", None): only_prompt_length}