Serkan007's picture
Sentence-Transformers ve E5-Large model aktarımı.
9bbba62 verified
from __future__ import annotations
import tempfile
from contextlib import nullcontext
from pathlib import Path
import pytest
import torch
from sentence_transformers import SparseEncoder, SparseEncoderTrainer, SparseEncoderTrainingArguments
from sentence_transformers.sparse_encoder import losses
from sentence_transformers.util import is_datasets_available, is_training_available
if is_datasets_available():
from datasets import Dataset, DatasetDict, IterableDatasetDict
if not is_training_available():
pytest.skip(
reason='Sentence Transformers was not installed with the `["train"]` extra.',
allow_module_level=True,
)
@pytest.fixture
def dummy_train_eval_datasets_for_trainer() -> tuple[Dataset, Dataset]:
# Create minimal datasets for trainer tests
train_data = {
"sentence1": [f"train_s1_{i}" for i in range(20)],
"sentence2": [f"train_s2_{i}" for i in range(20)],
"score": [float(i % 2) for i in range(20)],
}
eval_data = {
"sentence1": [f"eval_s1_{i}" for i in range(10)],
"sentence2": [f"eval_s2_{i}" for i in range(10)],
"score": [float(i % 2) for i in range(10)],
}
train_dataset = Dataset.from_dict(train_data)
eval_dataset = Dataset.from_dict(eval_data)
return train_dataset, eval_dataset
def test_model_card_reuse(splade_bert_tiny_model: SparseEncoder):
model = splade_bert_tiny_model
initial_card_text = model._model_card_text
SparseEncoderTrainer(
model=model,
loss=losses.SpladeLoss(
model=model,
loss=losses.SparseMultipleNegativesRankingLoss(model=model),
document_regularizer_weight=3e-5,
query_regularizer_weight=5e-5,
),
)
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmp_folder:
model_path = Path(tmp_folder) / "sparse_model_local"
model.save_pretrained(str(model_path))
with open(model_path / "README.md", encoding="utf8") as f:
trained_model_card_text = f.read()
if initial_card_text:
assert trained_model_card_text != initial_card_text
else:
assert trained_model_card_text is not None # Should have created one
@pytest.mark.parametrize("streaming", [False, True])
def test_trainer(
splade_bert_tiny_model: SparseEncoder,
dummy_train_eval_datasets_for_trainer: tuple[Dataset, Dataset],
streaming: bool,
) -> None:
model = splade_bert_tiny_model
train_dataset, eval_dataset = dummy_train_eval_datasets_for_trainer
context = nullcontext()
if streaming:
train_dataset = train_dataset.to_iterable_dataset()
eval_dataset = eval_dataset.to_iterable_dataset()
original_model_params = [p.clone() for p in model.parameters()]
loss = losses.SpladeLoss(
model=model,
loss=losses.SparseMultipleNegativesRankingLoss(model=model),
document_regularizer_weight=3e-5,
query_regularizer_weight=5e-5,
)
with tempfile.TemporaryDirectory() as temp_dir:
args = SparseEncoderTrainingArguments(
output_dir=str(temp_dir),
max_steps=2,
eval_strategy="steps", # Changed from eval_steps to eval_strategy
eval_steps=2,
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
logging_steps=1,
remove_unused_columns=False, # Important for custom dict datasets
report_to=["none"],
)
with context: # context is nullcontext unless streaming causes issues not caught here
trainer = SparseEncoderTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss,
)
trainer.train()
if isinstance(context, nullcontext):
# Check if model parameters have changed after training
model_changed = False
for p_orig, p_new in zip(original_model_params, model.parameters()):
if not torch.equal(p_orig, p_new):
model_changed = True
break
assert model_changed, "Model parameters should have changed after training."
# Simple check to ensure prediction works after training
try:
model.encode(["Test sentence after training."])
except Exception as e:
pytest.fail(f"Encoding failed after training: {e}")
@pytest.mark.slow
@pytest.mark.parametrize("train_dict", [False, True])
@pytest.mark.parametrize("eval_dict", [False, True])
@pytest.mark.parametrize("loss_dict", [False, True])
@pytest.mark.parametrize("add_transform", [False, True])
@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize(
"prompts",
[
None, # No prompt
"Prompt: ", # Single prompt to all columns and all datasets
{"stsb-1": "Prompt 1: ", "stsb-2": "Prompt 2: "}, # Different prompts for different datasets
{"sentence1": "Prompt 1: ", "sentence2": "Prompt 2: "}, # Different prompts for different columns
{
"stsb-1": {"sentence1": "Prompt 1: ", "sentence2": "Prompt 2: "},
"stsb-2": {"sentence1": "Prompt 3: ", "sentence2": "Prompt 4: "},
}, # Different prompts for different datasets and columns
],
)
def test_trainer_prompts(
splade_bert_tiny_model: SparseEncoder,
train_dict: bool,
eval_dict: bool,
loss_dict: bool,
add_transform: bool,
streaming: bool,
prompts: dict[str, dict[str, str]] | dict[str, str] | str | None,
):
if loss_dict and (not train_dict or not eval_dict):
pytest.skip(
"Skipping test because loss_dict=True requires train_dict=True and eval_dict=True; already tested via test_trainer."
)
model = splade_bert_tiny_model
train_dataset_1 = Dataset.from_dict(
{
"sentence1": ["train 1 sentence 1a", "train 1 sentence 1b"],
"sentence2": ["train 1 sentence 2a", "train 1 sentence 2b"],
}
)
train_dataset_2 = Dataset.from_dict(
{
"sentence1": ["train 2 sentence 1a", "train 2 sentence 1b"],
"sentence2": ["train 2 sentence 2a", "train 2 sentence 2b"],
}
)
eval_dataset_1 = Dataset.from_dict(
{
"sentence1": ["eval 1 sentence 1a", "eval 1 sentence 1b"],
"sentence2": ["eval 1 sentence 2a", "eval 1 sentence 2b"],
}
)
eval_dataset_2 = Dataset.from_dict(
{
"sentence1": ["eval 2 sentence 1a", "eval 2 sentence 1b"],
"sentence2": ["eval 2 sentence 2a", "eval 2 sentence 2b"],
}
)
loss = losses.SpladeLoss(
model=model,
loss=losses.SparseMultipleNegativesRankingLoss(model=model),
document_regularizer_weight=3e-5,
query_regularizer_weight=5e-5,
)
tracked_texts = []
old_preprocess = model.preprocess
def preprocess_tracker(texts, prompt=None, **kwargs):
if prompt:
tracked_texts.extend([prompt + text for text in texts])
else:
tracked_texts.extend(texts)
return old_preprocess(texts, prompt=prompt, **kwargs)
model.preprocess = preprocess_tracker
if train_dict:
if streaming:
train_dataset = IterableDatasetDict({"stsb-1": train_dataset_1, "stsb-2": train_dataset_2})
else:
train_dataset = DatasetDict({"stsb-1": train_dataset_1, "stsb-2": train_dataset_2})
else:
if streaming:
train_dataset = train_dataset_1.to_iterable_dataset()
else:
train_dataset = train_dataset_1
if eval_dict:
if streaming:
eval_dataset = IterableDatasetDict({"stsb-1": eval_dataset_1, "stsb-2": eval_dataset_2})
else:
eval_dataset = DatasetDict({"stsb-1": eval_dataset_1, "stsb-2": eval_dataset_2})
else:
if streaming:
eval_dataset = eval_dataset_1.to_iterable_dataset()
else:
eval_dataset = eval_dataset_1
def upper_transform(batch):
for column_name, column in batch.items():
batch[column_name] = [text.upper() for text in column]
return batch
if add_transform:
if streaming:
if train_dict:
train_dataset = IterableDatasetDict(
{
dataset_name: dataset.map(upper_transform, batched=True, features=dataset.features)
for dataset_name, dataset in train_dataset.items()
}
)
else:
train_dataset = train_dataset.map(upper_transform, batched=True, features=train_dataset.features)
if eval_dict:
eval_dataset = IterableDatasetDict(
{
dataset_name: dataset.map(upper_transform, batched=True, features=dataset.features)
for dataset_name, dataset in eval_dataset.items()
}
)
else:
eval_dataset = eval_dataset.map(upper_transform, batched=True, features=eval_dataset.features)
else:
train_dataset.set_transform(upper_transform)
eval_dataset.set_transform(upper_transform)
if loss_dict:
loss = {
"stsb-1": loss,
"stsb-2": loss,
}
# Variables to more easily track the expected outputs. Uppercased if add_transform is True as we expect
# the transform to be applied to the data.
all_train_1_1 = {s.upper() if add_transform else s for s in train_dataset_1["sentence1"]}
all_train_1_2 = {s.upper() if add_transform else s for s in train_dataset_1["sentence2"]}
all_train_2_1 = {s.upper() if add_transform else s for s in train_dataset_2["sentence1"]}
all_train_2_2 = {s.upper() if add_transform else s for s in train_dataset_2["sentence2"]}
all_eval_1_1 = {s.upper() if add_transform else s for s in eval_dataset_1["sentence1"]}
all_eval_1_2 = {s.upper() if add_transform else s for s in eval_dataset_1["sentence2"]}
all_eval_2_1 = {s.upper() if add_transform else s for s in eval_dataset_2["sentence1"]}
all_eval_2_2 = {s.upper() if add_transform else s for s in eval_dataset_2["sentence2"]}
all_train_1 = all_train_1_1 | all_train_1_2
all_train_2 = all_train_2_1 | all_train_2_2
all_eval_1 = all_eval_1_1 | all_eval_1_2
all_eval_2 = all_eval_2_1 | all_eval_2_2
all_train = all_train_1 | all_train_2
all_eval = all_eval_1 | all_eval_2
if prompts == {
"stsb-1": {"sentence1": "Prompt 1: ", "sentence2": "Prompt 2: "},
"stsb-2": {"sentence1": "Prompt 3: ", "sentence2": "Prompt 4: "},
} and (train_dict, eval_dict) != (True, True):
context = pytest.raises(
ValueError,
match="The prompts provided to the trainer are a nested dictionary. In this setting, the first "
"level of the dictionary should map to dataset names and the second level to column names. "
"However, as the provided dataset is a not a DatasetDict, no dataset names can be inferred. "
"The keys to the provided prompts dictionary are .*",
)
else:
context = nullcontext()
with tempfile.TemporaryDirectory() as temp_dir:
args = SparseEncoderTrainingArguments(
output_dir=str(temp_dir),
prompts=prompts,
max_steps=4 if train_dict else 2,
eval_steps=4 if train_dict else 2,
eval_strategy="steps",
per_device_train_batch_size=1,
per_device_eval_batch_size=1,
report_to=["none"],
)
trainer = SparseEncoderTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss,
)
tracked_texts.clear()
datacollator_keys = set()
old_compute_loss = trainer.compute_loss
def compute_loss_tracker(model, inputs, **kwargs):
datacollator_keys.update(set(inputs.keys()))
return old_compute_loss(model, inputs, **kwargs)
trainer.compute_loss = compute_loss_tracker
with context:
trainer.train()
if not isinstance(context, nullcontext):
return
# prompt_length keys may appear in the batch when prompts are provided (Transformer.preprocess always
# computes them), but SpladePooling simply ignores them. Only Pooling uses them when include_prompt=False.
# We only need the dataset_name if the loss requires it, or the prompts are a nested dictionary
if (train_dict or eval_dict) and (loss_dict or (isinstance(prompts, dict))):
assert "dataset_name" in datacollator_keys
else:
assert "dataset_name" not in datacollator_keys
if prompts is None:
if (train_dict, eval_dict) == (False, False):
expected = all_train_1 | all_eval_1
elif (train_dict, eval_dict) == (True, False):
expected = all_train | all_eval_1
elif (train_dict, eval_dict) == (False, True):
expected = all_train_1 | all_eval
elif (train_dict, eval_dict) == (True, True):
expected = all_train | all_eval
elif prompts == "Prompt: ":
if (train_dict, eval_dict) == (False, False):
expected = {prompts + sample for sample in all_train_1} | {prompts + sample for sample in all_eval_1}
elif (train_dict, eval_dict) == (True, False):
expected = {prompts + sample for sample in all_train} | {prompts + sample for sample in all_eval_1}
elif (train_dict, eval_dict) == (False, True):
expected = {prompts + sample for sample in all_train_1} | {prompts + sample for sample in all_eval}
elif (train_dict, eval_dict) == (True, True):
expected = {prompts + sample for sample in all_train} | {prompts + sample for sample in all_eval}
elif prompts == {"stsb-1": "Prompt 1: ", "stsb-2": "Prompt 2: "}:
# If we don't have dataset dictionaries, the prompts will be seen as column names
if (train_dict, eval_dict) == (False, False):
expected = all_train_1 | all_eval_1
elif (train_dict, eval_dict) == (True, False):
expected = (
{prompts["stsb-1"] + sample for sample in all_train_1}
| {prompts["stsb-2"] + sample for sample in all_train_2}
| all_eval_1
)
elif (train_dict, eval_dict) == (False, True):
expected = (
all_train_1
| {prompts["stsb-1"] + sample for sample in all_eval_1}
| {prompts["stsb-2"] + sample for sample in all_eval_2}
)
elif (train_dict, eval_dict) == (True, True):
expected = (
{prompts["stsb-1"] + sample for sample in all_train_1}
| {prompts["stsb-2"] + sample for sample in all_train_2}
| {prompts["stsb-1"] + sample for sample in all_eval_1}
| {prompts["stsb-2"] + sample for sample in all_eval_2}
)
elif prompts == {"sentence1": "Prompt 1: ", "sentence2": "Prompt 2: "}:
if (train_dict, eval_dict) == (False, False):
expected = (
{prompts["sentence1"] + sample for sample in all_train_1_1}
| {prompts["sentence2"] + sample for sample in all_train_1_2}
| {prompts["sentence1"] + sample for sample in all_eval_1_1}
| {prompts["sentence2"] + sample for sample in all_eval_1_2}
)
elif (train_dict, eval_dict) == (True, False):
expected = (
{prompts["sentence1"] + sample for sample in all_train_1_1}
| {prompts["sentence2"] + sample for sample in all_train_1_2}
| {prompts["sentence1"] + sample for sample in all_train_2_1}
| {prompts["sentence2"] + sample for sample in all_train_2_2}
| {prompts["sentence1"] + sample for sample in all_eval_1_1}
| {prompts["sentence2"] + sample for sample in all_eval_1_2}
)
elif (train_dict, eval_dict) == (False, True):
expected = (
{prompts["sentence1"] + sample for sample in all_train_1_1}
| {prompts["sentence2"] + sample for sample in all_train_1_2}
| {prompts["sentence1"] + sample for sample in all_eval_1_1}
| {prompts["sentence2"] + sample for sample in all_eval_1_2}
| {prompts["sentence1"] + sample for sample in all_eval_2_1}
| {prompts["sentence2"] + sample for sample in all_eval_2_2}
)
elif (train_dict, eval_dict) == (True, True):
expected = (
{prompts["sentence1"] + sample for sample in all_train_1_1}
| {prompts["sentence2"] + sample for sample in all_train_1_2}
| {prompts["sentence1"] + sample for sample in all_train_2_1}
| {prompts["sentence2"] + sample for sample in all_train_2_2}
| {prompts["sentence1"] + sample for sample in all_eval_1_1}
| {prompts["sentence2"] + sample for sample in all_eval_1_2}
| {prompts["sentence1"] + sample for sample in all_eval_2_1}
| {prompts["sentence2"] + sample for sample in all_eval_2_2}
)
elif prompts == {
"stsb-1": {"sentence1": "Prompt 1: ", "sentence2": "Prompt 2: "},
"stsb-2": {"sentence1": "Prompt 3: ", "sentence2": "Prompt 4: "},
}:
# All other cases are tested above with the ValueError context
if (train_dict, eval_dict) == (True, True):
expected = (
{prompts["stsb-1"]["sentence1"] + sample for sample in all_train_1_1}
| {prompts["stsb-1"]["sentence2"] + sample for sample in all_train_1_2}
| {prompts["stsb-2"]["sentence1"] + sample for sample in all_train_2_1}
| {prompts["stsb-2"]["sentence2"] + sample for sample in all_train_2_2}
| {prompts["stsb-1"]["sentence1"] + sample for sample in all_eval_1_1}
| {prompts["stsb-1"]["sentence2"] + sample for sample in all_eval_1_2}
| {prompts["stsb-2"]["sentence1"] + sample for sample in all_eval_2_1}
| {prompts["stsb-2"]["sentence2"] + sample for sample in all_eval_2_2}
)
assert set(tracked_texts) == expected