from __future__ import annotations import json import logging import re import tempfile from pathlib import Path import numpy as np import pytest import torch from huggingface_hub import CommitInfo, HfApi, RepoUrl from packaging.version import Version, parse from pytest import FixtureRequest from transformers import __version__ as transformers_version from sentence_transformers import CrossEncoder from sentence_transformers.util import fullname from sentence_transformers.util.decorators import ( cross_encoder_init_args_decorator, cross_encoder_predict_rank_args_decorator, ) def test_classifier_dropout_is_set() -> None: model = CrossEncoder("cross-encoder-testing/reranker-bert-tiny-gooaq-bce", classifier_dropout=0.1234) assert model.config.classifier_dropout == 0.1234 assert model.model.config.classifier_dropout == 0.1234 def test_classifier_dropout_default_value() -> None: model = CrossEncoder("cross-encoder-testing/reranker-bert-tiny-gooaq-bce") assert model.config.classifier_dropout is None assert model.model.config.classifier_dropout is None def test_load_with_revision() -> None: model_name = "sentence-transformers-testing/stsb-bert-tiny-safetensors" main_model = CrossEncoder(model_name, num_labels=1, revision="main") latest_model = CrossEncoder( model_name, num_labels=1, revision="f3cb857cba53019a20df283396bcca179cf051a4", ) older_model = CrossEncoder( model_name, num_labels=1, revision="ba33022fdf0b0fc2643263f0726f44d0a07d0e24", ) # Set the classifier.bias and classifier.weight equal among models. This # is needed because the AutoModelForSequenceClassification randomly initializes # the classifier.bias and classifier.weight for each (model) initialization. # The test is only possible if all models have the same classifier.bias # and classifier.weight parameters. 
latest_model.model.classifier.bias = main_model.model.classifier.bias latest_model.model.classifier.weight = main_model.model.classifier.weight older_model.model.classifier.bias = main_model.model.classifier.bias older_model.model.classifier.weight = main_model.model.classifier.weight test_sentences = [["Hello there!", "Hello, World!"]] main_prob = main_model.predict(test_sentences, convert_to_tensor=True) assert torch.equal(main_prob, latest_model.predict(test_sentences, convert_to_tensor=True)) assert not torch.equal(main_prob, older_model.predict(test_sentences, convert_to_tensor=True)) @pytest.mark.parametrize( argnames="return_documents", argvalues=[True, False], ids=["return-docs", "no-return-docs"], ) def test_rank(return_documents: bool, request: FixtureRequest) -> None: model = CrossEncoder("cross-encoder/stsb-distilroberta-base") # We want to compute the similarity between the query sentence query = "A man is eating pasta." # With all sentences in the corpus corpus = [ "A man is eating food.", "A man is eating a piece of bread.", "The girl is carrying a baby.", "A man is riding a horse.", "A woman is playing violin.", "Two men pushed carts through the woods.", "A man is riding a white horse on an enclosed ground.", "A monkey is playing drums.", "A cheetah is running behind its prey.", ] expected_ranking = [0, 1, 3, 6, 2, 5, 7, 4, 8] # 1. We rank all sentences in the corpus for the query ranks = model.rank(query=query, documents=corpus, return_documents=return_documents) if request.node.callspec.id == "return-docs": assert {*corpus} == {rank.get("text") for rank in ranks} pred_ranking = [rank["corpus_id"] for rank in ranks] assert pred_ranking == expected_ranking def test_rank_multiple_labels(): model = CrossEncoder("cross-encoder/nli-MiniLM2-L6-H768") with pytest.raises( ValueError, match=re.escape( "CrossEncoder.rank() only works for models with num_labels=1. " "Consider using CrossEncoder.predict() with input pairs instead." 
), ): model.rank( query="A man is eating pasta.", documents=[ "A man is eating food.", "A man is eating a piece of bread.", "The girl is carrying a baby.", ], ) def test_predict_softmax(): model = CrossEncoder("cross-encoder/nli-MiniLM2-L6-H768") query = "A man is eating pasta." # With all sentences in the corpus corpus = [ "A man is eating food.", "A man is eating a piece of bread.", "The girl is carrying a baby.", "A man is riding a horse.", ] scores = model.predict([(query, doc) for doc in corpus], apply_softmax=True, convert_to_tensor=True) assert torch.isclose(scores.sum(1), torch.ones(len(corpus), device=scores.device)).all() scores = model.predict([(query, doc) for doc in corpus], apply_softmax=False, convert_to_tensor=True) assert not torch.isclose(scores.sum(1), torch.ones(len(corpus), device=scores.device)).all() @pytest.mark.parametrize( "model_name", ["cross-encoder-testing/reranker-bert-tiny-gooaq-bce", "cross-encoder/nli-MiniLM2-L6-H768"] ) def test_predict_single_input(model_name: str): model = CrossEncoder(model_name) nested_pair_score = model.predict([["A man is eating pasta.", "A man is eating food."]]) assert isinstance(nested_pair_score, np.ndarray) if model.num_labels == 1: assert nested_pair_score.shape == (1,) else: assert nested_pair_score.shape == (1, model.num_labels) pair_score = model.predict(["A man is eating pasta.", "A man is eating food."]) if model.num_labels == 1: assert isinstance(pair_score, np.float32) else: assert isinstance(pair_score, np.ndarray) assert pair_score.shape == (model.num_labels,) def test_is_singular_input_numpy_1d_pair(reranker_bert_tiny_model: CrossEncoder) -> None: """A 1D numpy string array represents a single (query, document) pair.""" assert reranker_bert_tiny_model.is_singular_input(np.array(["query", "document"])) is True def test_is_singular_input_numpy_2d_pairs(reranker_bert_tiny_model: CrossEncoder) -> None: """A 2D numpy string array is a batch of pairs.""" assert 
reranker_bert_tiny_model.is_singular_input(np.array([["q1", "d1"], ["q2", "d2"]])) is False def test_is_singular_input_numpy_empty(reranker_bert_tiny_model: CrossEncoder) -> None: """An empty 1D string ndarray is an empty batch, not a singular pair, matching ``predict([])``.""" assert reranker_bert_tiny_model.is_singular_input(np.array([], dtype=str)) is False def test_predict_numpy_empty(reranker_bert_tiny_model: CrossEncoder) -> None: """Predicting on an empty string ndarray should return an empty array, like ``predict([])``.""" scores = reranker_bert_tiny_model.predict(np.array([], dtype=str), show_progress_bar=False) expected = reranker_bert_tiny_model.predict([], show_progress_bar=False) assert scores.shape == (0,) assert np.array_equal(scores, expected) def test_predict_numpy_1d_pair(reranker_bert_tiny_model: CrossEncoder) -> None: """Predicting on a 1D numpy string array (a single pair) should match the tuple equivalent and return a scalar score. Exercises the singular-branch .tolist() conversion.""" model = reranker_bert_tiny_model pair = np.array(["what is AI?", "AI is artificial intelligence."]) score = model.predict(pair, show_progress_bar=False) expected = model.predict(tuple(pair.tolist()), show_progress_bar=False) assert isinstance(score, np.float32) assert np.allclose(score, expected) def test_predict_numpy_2d_pairs(reranker_bert_tiny_model: CrossEncoder) -> None: """Predicting on a 2D numpy string array should match predicting on the equivalent nested list.""" pairs = np.array([["what is AI?", "AI is artificial intelligence."], ["what is ML?", "ML is machine learning."]]) scores = reranker_bert_tiny_model.predict(pairs, show_progress_bar=False) expected = reranker_bert_tiny_model.predict(pairs.tolist(), show_progress_bar=False) assert scores.shape == (2,) assert np.allclose(scores, expected) def test_predict_batch_size_1(reranker_bert_tiny_model: CrossEncoder) -> None: """Regression test: batch_size=1 with num_labels=1 used to fail because squeeze 
produced a 0-d tensor. Some models (e.g. jinaai/jina-reranker-m0) return scores with shape [batch_size] instead of [batch_size, 1]. With batch_size=1, squeeze(-1) would collapse [1] to a 0-d scalar, causing .extend() to fail. We mock forward to reproduce this by stripping the trailing dimension. """ model = reranker_bert_tiny_model pairs = [ ["A man is eating pasta.", "A man is eating food."], ["The girl is carrying a baby.", "A man is riding a horse."], ] original_forward = model.forward def forward_without_trailing_dim(features, **kwargs): out = original_forward(features, **kwargs) # Simulate models that return [batch_size] instead of [batch_size, 1] out["scores"] = out["scores"].squeeze(-1) return out model.forward = forward_without_trailing_dim scores = model.predict(pairs, batch_size=1) assert isinstance(scores, np.ndarray) assert scores.shape == (2,) @pytest.mark.parametrize("convert_to_numpy", [True, False]) @pytest.mark.parametrize("convert_to_tensor", [True, False]) def test_empty_predict(reranker_bert_tiny_model: CrossEncoder, convert_to_numpy: bool, convert_to_tensor: bool): model = reranker_bert_tiny_model result = model.predict([], convert_to_numpy=convert_to_numpy, convert_to_tensor=convert_to_tensor) if convert_to_tensor: assert isinstance(result, torch.Tensor) assert result.numel() == 0 assert result.device == model.model.device elif convert_to_numpy: assert isinstance(result, np.ndarray) assert result.size == 0 else: assert result == [] @pytest.mark.parametrize("convert_to_tensor", [True, False]) @pytest.mark.parametrize("convert_to_numpy", [True, False]) def test_predict_output_types(convert_to_tensor: bool, convert_to_numpy: bool) -> None: model = CrossEncoder("cross-encoder-testing/reranker-bert-tiny-gooaq-bce") embeddings = model.predict( [["One sentence", "Another sentence"]], convert_to_tensor=convert_to_tensor, convert_to_numpy=convert_to_numpy, ) if convert_to_tensor: assert embeddings[0].dtype == torch.float32 assert isinstance(embeddings, 
torch.Tensor) elif convert_to_numpy: assert embeddings[0].dtype == np.float32 assert isinstance(embeddings, np.ndarray) else: assert embeddings[0].dtype == torch.float32 assert isinstance(embeddings, list) @pytest.mark.parametrize("safe_serialization", [True, False, None]) def test_safe_serialization(safe_serialization: bool) -> None: with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as cache_folder: model = CrossEncoder("cross-encoder-testing/reranker-bert-tiny-gooaq-bce") if safe_serialization: model.save_pretrained(cache_folder, safe_serialization=safe_serialization) model_files = list(Path(cache_folder).glob("**/model.safetensors")) assert 1 == len(model_files) elif safe_serialization is None: model.save_pretrained(cache_folder) model_files = list(Path(cache_folder).glob("**/model.safetensors")) assert 1 == len(model_files) else: # For transformers v5.0, safe_serialization is quietly ignored if parse(transformers_version) < Version("5.0.0dev0"): model.save_pretrained(cache_folder, safe_serialization=safe_serialization) model_files = list(Path(cache_folder).glob("**/pytorch_model.bin")) assert 1 == len(model_files) def test_bfloat16() -> None: model = CrossEncoder( "cross-encoder-testing/reranker-bert-tiny-gooaq-bce", automodel_args={"torch_dtype": torch.bfloat16} ) score = model.predict([["Hello there!", "Hello, World!"]]) assert isinstance(score, np.ndarray) ranking = model.rank("Hello there!", ["Hello, World!", "Heya!"]) assert isinstance(ranking, list) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.") @pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_device_assignment(device): model = CrossEncoder("cross-encoder-testing/reranker-bert-tiny-gooaq-bce", device=device) assert model.device.type == device @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.") def test_device_switching(): # test assignment 
using .to model = CrossEncoder("cross-encoder-testing/reranker-bert-tiny-gooaq-bce", device="cpu") assert model.device.type == "cpu" assert model.model.device.type == "cpu" model.to("cuda") assert model.device.type == "cuda" assert model.model.device.type == "cuda" del model torch.cuda.empty_cache() @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.") def test_target_device_backwards_compat(): model = CrossEncoder("cross-encoder-testing/reranker-bert-tiny-gooaq-bce", device="cpu") assert model.device.type == "cpu" assert model._target_device.type == "cpu" model._target_device = "cuda" assert model.device.type == "cuda" def test_num_labels_fresh_model(): model = CrossEncoder("sentence-transformers-testing/stsb-bert-tiny-safetensors") assert model.num_labels == 1 def test_push_to_hub( reranker_bert_tiny_model: CrossEncoder, monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture ) -> None: model = reranker_bert_tiny_model def mock_create_repo(self, repo_id, **kwargs): return RepoUrl(f"https://huggingface.co/{repo_id}") mock_upload_folder_kwargs = {} def mock_upload_folder(self, **kwargs): nonlocal mock_upload_folder_kwargs mock_upload_folder_kwargs = kwargs commit_hash = "123456" if kwargs.get("revision") is None else "678901" commit_info_kwargs = { "commit_url": f"https://huggingface.co/{kwargs.get('repo_id')}/commit/{commit_hash}", "commit_message": "commit_message", "commit_description": "commit_description", "oid": "oid", "pr_url": f"https://huggingface.co/{kwargs.get('repo_id')}/discussions/123", } try: return CommitInfo(**commit_info_kwargs) except TypeError: # Required as of https://github.com/huggingface/huggingface_hub/pull/3679 return CommitInfo(**commit_info_kwargs, _endpoint=None) def mock_create_branch(self, repo_id, branch, revision=None, **kwargs): return None monkeypatch.setattr(HfApi, "create_repo", mock_create_repo) monkeypatch.setattr(HfApi, "upload_folder", 
mock_upload_folder) monkeypatch.setattr(HfApi, "create_branch", mock_create_branch) url = model.push_to_hub("cross-encoder-testing/stsb-distilroberta-base") assert mock_upload_folder_kwargs["repo_id"] == "cross-encoder-testing/stsb-distilroberta-base" assert url == "https://huggingface.co/cross-encoder-testing/stsb-distilroberta-base/commit/123456" mock_upload_folder_kwargs.clear() url = model.push_to_hub("cross-encoder-testing/stsb-distilroberta-base", revision="revision_test") assert mock_upload_folder_kwargs["repo_id"] == "cross-encoder-testing/stsb-distilroberta-base" assert mock_upload_folder_kwargs["revision"] == "revision_test" assert url == "https://huggingface.co/cross-encoder-testing/stsb-distilroberta-base/commit/678901" mock_upload_folder_kwargs.clear() url = model.push_to_hub("cross-encoder-testing/stsb-distilroberta-base", create_pr=True) assert mock_upload_folder_kwargs["repo_id"] == "cross-encoder-testing/stsb-distilroberta-base" assert url == "https://huggingface.co/cross-encoder-testing/stsb-distilroberta-base/discussions/123" mock_upload_folder_kwargs.clear() @pytest.mark.parametrize( ["in_args", "in_kwargs", "out_args", "out_kwargs"], [ [ tuple(), {"model_name": "cross-encoder-testing/reranker-bert-tiny-gooaq-bce", "classifier_dropout": 0.1234}, tuple(), { "model_name_or_path": "cross-encoder-testing/reranker-bert-tiny-gooaq-bce", "config_kwargs": {"classifier_dropout": 0.1234}, }, ], [ ("cross-encoder-testing/reranker-bert-tiny-gooaq-bce",), {"classifier_dropout": 0.1234}, ("cross-encoder-testing/reranker-bert-tiny-gooaq-bce",), {"config_kwargs": {"classifier_dropout": 0.1234}}, ], [ ("cross-encoder-testing/reranker-bert-tiny-gooaq-bce",), { "automodel_args": {"foo": "bar"}, "tokenizer_args": {"foo": "baz"}, }, ("cross-encoder-testing/reranker-bert-tiny-gooaq-bce",), { "model_kwargs": {"foo": "bar"}, "processor_kwargs": {"foo": "baz"}, }, ], [ ("cross-encoder-testing/reranker-bert-tiny-gooaq-bce",), { "config_args": {"foo": "bar"}, "cache_dir": 
"local_tmp", }, ("cross-encoder-testing/reranker-bert-tiny-gooaq-bce",), { "config_kwargs": {"foo": "bar"}, "cache_folder": "local_tmp", }, ], [ ("cross-encoder-testing/reranker-bert-tiny-gooaq-bce",), { "automodel_args": {"foo": "bar"}, "model_kwargs": {"faa": "baz"}, }, ("cross-encoder-testing/reranker-bert-tiny-gooaq-bce",), { "model_kwargs": {"faa": "baz"}, }, ], [ ("cross-encoder-testing/reranker-bert-tiny-gooaq-bce",), { "default_activation_function": "torch.nn.Sigmoid", }, ("cross-encoder-testing/reranker-bert-tiny-gooaq-bce",), { "activation_fn": "torch.nn.Sigmoid", }, ], [tuple(), {}, tuple(), {}], [ ("cross-encoder-testing/reranker-bert-tiny-gooaq-bce",), {}, ("cross-encoder-testing/reranker-bert-tiny-gooaq-bce",), {}, ], [ tuple(), { "model_name": "cross-encoder-testing/reranker-bert-tiny-gooaq-bce", "automodel_args": {"foo": "bar"}, "tokenizer_args": {"foo": "baz"}, "config_args": {"foo": "bar"}, "cache_dir": "local_tmp", }, tuple(), { "model_name_or_path": "cross-encoder-testing/reranker-bert-tiny-gooaq-bce", "model_kwargs": {"foo": "bar"}, "processor_kwargs": {"foo": "baz"}, "config_kwargs": {"foo": "bar"}, "cache_folder": "local_tmp", }, ], ], ) def test_init_args_decorator( monkeypatch: pytest.MonkeyPatch, in_args: tuple, in_kwargs: dict, out_args: tuple, out_kwargs: dict ): decorated_out_args = None decorated_out_kwargs = None @cross_encoder_init_args_decorator def mock_init(self, *args, **kwargs): nonlocal decorated_out_args nonlocal decorated_out_kwargs decorated_out_args = args decorated_out_kwargs = kwargs return None monkeypatch.setattr(CrossEncoder, "__init__", mock_init) CrossEncoder(*in_args, **in_kwargs) assert decorated_out_args == out_args assert decorated_out_kwargs == out_kwargs @pytest.mark.parametrize( ["in_kwargs", "out_kwargs"], [ [ {"inputs": [["Hello there!", "Hello, World!"]], "num_workers": 2}, {"inputs": [["Hello there!", "Hello, World!"]]}, ], [ { "inputs": [["Hello there!", "Hello, World!"]], "activation_fct": 
torch.nn.Identity, "activation_fn": torch.nn.Sigmoid, }, {"inputs": [["Hello there!", "Hello, World!"]], "activation_fn": torch.nn.Sigmoid}, ], [ {"sentences": [["Hello there!", "Hello, World!"]]}, {"inputs": [["Hello there!", "Hello, World!"]]}, ], ], ) def test_predict_rank_args_decorator( reranker_bert_tiny_model: CrossEncoder, monkeypatch: pytest.MonkeyPatch, caplog, in_kwargs: dict, out_kwargs: dict ): model = reranker_bert_tiny_model decorated_out_kwargs = None @cross_encoder_predict_rank_args_decorator def mock_predict(self, *args, **kwargs): nonlocal decorated_out_kwargs decorated_out_kwargs = kwargs return None monkeypatch.setattr(CrossEncoder, "predict", mock_predict) with caplog.at_level(logging.WARNING): model.predict(**in_kwargs) assert caplog.text != "" assert decorated_out_kwargs == out_kwargs def test_logger_warning(caplog): model_name = "cross-encoder-testing/reranker-bert-tiny-gooaq-bce" with caplog.at_level(logging.WARNING): CrossEncoder(model_name, classifier_dropout=0.1234) assert "`classifier_dropout` argument is deprecated" in caplog.text with caplog.at_level(logging.WARNING): CrossEncoder(model_name, automodel_args={"torch_dtype": torch.float32}) assert "`automodel_args` argument was renamed and is now deprecated" in caplog.text with caplog.at_level(logging.WARNING): CrossEncoder(model_name, tokenizer_args={"model_max_length": 8192}) assert "`tokenizer_args` argument was renamed and is now deprecated" in caplog.text with caplog.at_level(logging.WARNING): CrossEncoder(model_name, config_args={"classifier_dropout": 0.2}) assert "`config_args` argument was renamed and is now deprecated" in caplog.text @pytest.mark.parametrize( ["num_labels", "activation_fn", "saved_activation_fn"], [ [ 1, torch.nn.Sigmoid(), "torch.nn.modules.activation.Sigmoid", ], [ 1, torch.nn.Identity(), "torch.nn.modules.linear.Identity", ], [ 1, torch.nn.Tanh(), "torch.nn.modules.activation.Tanh", ], [ 1, torch.nn.Softmax(), "torch.nn.modules.activation.Softmax", ], [ 1, 
None, "torch.nn.modules.activation.Sigmoid", ], [ 3, None, "torch.nn.modules.linear.Identity", ], ], ) def test_load_activation_fn_from_kwargs(num_labels: int, activation_fn: str, saved_activation_fn: str, tmp_path: Path): model = CrossEncoder( "sentence-transformers-testing/stsb-bert-tiny-safetensors", num_labels=num_labels, activation_fn=activation_fn ) assert fullname(model.activation_fn) == saved_activation_fn model.save_pretrained(tmp_path) with open(tmp_path / "config_sentence_transformers.json") as f: config = json.load(f) assert config["activation_fn"] == saved_activation_fn loaded_model = CrossEncoder(str(tmp_path)) assert fullname(loaded_model.activation_fn) == saved_activation_fn # Setting the activation function via a predict call only updates it for that call loaded_model.predict([["Hello there!", "Hello, World!"]], activation_fn=torch.nn.Identity()) assert fullname(loaded_model.activation_fn) == saved_activation_fn # But we can also override it again when loading the model loaded_model = CrossEncoder(str(tmp_path), activation_fn=torch.nn.Identity()) assert fullname(loaded_model.activation_fn) == "torch.nn.modules.linear.Identity" @pytest.mark.parametrize( "tanh_model_name", [ "cross-encoder-testing/reranker-bert-tiny-gooaq-bce-tanh-v3", "cross-encoder-testing/reranker-bert-tiny-gooaq-bce-tanh-v4", ], ) def test_load_activation_fn_from_config(tanh_model_name: str, tmp_path): saved_activation_fn = "torch.nn.modules.activation.Tanh" model = CrossEncoder(tanh_model_name) assert fullname(model.activation_fn) == saved_activation_fn model.save_pretrained(tmp_path) with open(tmp_path / "config_sentence_transformers.json") as f: config = json.load(f) assert config["activation_fn"] == saved_activation_fn loaded_model = CrossEncoder(str(tmp_path)) assert fullname(loaded_model.activation_fn) == saved_activation_fn def test_load_activation_fn_from_config_custom(reranker_bert_tiny_model: CrossEncoder, tmp_path: Path, caplog): model = reranker_bert_tiny_model 
model.save_pretrained(tmp_path) with open(tmp_path / "config_sentence_transformers.json") as f: config = json.load(f) config["activation_fn"] = "sentence_transformers.custom.activations.CustomActivation" with open(tmp_path / "config_sentence_transformers.json", "w") as f: json.dump(config, f) with caplog.at_level(logging.WARNING): CrossEncoder(str(tmp_path)) assert ( "Activation function path 'sentence_transformers.custom.activations.CustomActivation' is not trusted, using default activation function instead." in caplog.text ) # If we use trust_remote_code, it'll try to load the custom activation function, which doesn't exist with pytest.raises(ModuleNotFoundError): model = CrossEncoder(str(tmp_path), trust_remote_code=True) def test_default_activation_fn(reranker_bert_tiny_model: CrossEncoder): model = reranker_bert_tiny_model assert fullname(model.activation_fn) == "torch.nn.modules.activation.Sigmoid" with pytest.warns( DeprecationWarning, match="The `default_activation_function` property was renamed and is now deprecated.*" ): assert fullname(model.default_activation_function) == "torch.nn.modules.activation.Sigmoid" def test_bge_reranker_max_length(): model = CrossEncoder("BAAI/bge-reranker-base") assert model.max_length == 512 assert model.tokenizer.model_max_length == 512 model.max_length = 256 assert model.max_length == 256 assert model.tokenizer.model_max_length == 256 def test_predict_with_dataset_column(reranker_bert_tiny_model: CrossEncoder) -> None: """Test that predict can handle a dataset column as input.""" model = reranker_bert_tiny_model from datasets import Dataset # Create a simple dataset with a text column dataset = Dataset.from_dict( { "text": [ ["This is the start of a pair.", "And this the end."], ["This is a second pair.", "And this the end of the second pair."], ] } ) # Encode the dataset column embeddings = model.predict(dataset["text"], convert_to_tensor=True) # Check the shape of the embeddings assert embeddings.shape == (2,) # Test 
# Test suite converted from demo_3406_simple_og.py, exercising Qwen3 rerankers.
#
# The inputs follow the prompt format from the Qwen3-Reranker model card: a system/user chat
# prefix, then "<Instruct>: ...", "<Query>: ..." and "<Document>: ..." lines, then an assistant
# turn that opens with an empty <think> block. All four tests below compare against the same
# pinned scores, so the manual formatting helpers and the chat template must render identical text.

_QWEN3_TASK = "Given a web search query, retrieve relevant passages that answer the query"

_QWEN3_QUERY = "Which planet is known as the Red Planet?"

_QWEN3_DOCUMENTS = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet.",
]

_QWEN3_EXPECTED_SCORES = [-3.109297752380371, 7.120389938354492, -0.3787546157836914, 3.541637420654297]

# Chat template equivalent to ``format_queries`` + ``format_document``. The instruction falls back to
# the web-search task when no "system" message (prompt) is supplied. The extra trailing "\n" is
# presumably there because Jinja drops a single trailing newline by default
# (keep_trailing_newline=False), so the rendered text ends in "</think>\n\n" like ``format_document``.
_QWEN3_CHAT_TEMPLATE = """\
<|im_start|>system
Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>
<|im_start|>user
<Instruct>: {{ messages | selectattr("role", "eq", "system") | map(attribute="content") | first | default("Given a web search query, retrieve relevant passages that answer the query") }}
<Query>: {{ messages | selectattr("role", "eq", "query") | map(attribute="content") | first }}
<Document>: {{ messages | selectattr("role", "eq", "document") | map(attribute="content") | first }}<|im_end|>
<|im_start|>assistant
<think>\n\n</think>\n\n\n"""

_requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")


def format_queries(query, instruction=None):
    """Format the query half of a Qwen3-Reranker input.

    Args:
        query: The search query.
        instruction: Task instruction; the web-search instruction is used when ``None``.

    Returns:
        The system/user chat prefix followed by the "<Instruct>:" and "<Query>:" lines.
    """
    prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
    if instruction is None:
        instruction = _QWEN3_TASK
    return f"{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"


def format_document(document):
    """Format the document half of a Qwen3-Reranker input.

    Args:
        document: The candidate passage.

    Returns:
        The "<Document>:" line followed by the end of the user turn and an assistant turn
        containing an empty <think> block.
    """
    suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
    return f"<Document>: {document}{suffix}"


@_requires_cuda
def test_qwen3_reranker_formatted_pairs():
    """Test Qwen3 Reranker with manually formatted query-document pairs."""
    model = CrossEncoder("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", activation_fn=torch.nn.Identity())

    pairs = [[format_queries(_QWEN3_QUERY, _QWEN3_TASK), format_document(doc)] for doc in _QWEN3_DOCUMENTS]
    scores = model.predict(pairs)

    # Assert scores match expected values with tolerance
    assert scores == pytest.approx(_QWEN3_EXPECTED_SCORES, abs=1e-4)


@_requires_cuda
def test_qwen3_reranker_with_chat_template():
    """Test Qwen3 Reranker with a chat template, supplying the instruction through a default prompt."""
    model = CrossEncoder(
        "tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
        activation_fn=torch.nn.Identity(),
        processor_kwargs={"chat_template": _QWEN3_CHAT_TEMPLATE},
        prompts={"web_search": _QWEN3_TASK},
        default_prompt_name="web_search",
    )

    pairs = [(_QWEN3_QUERY, doc) for doc in _QWEN3_DOCUMENTS]
    scores = model.predict(pairs)

    # Assert scores match expected values with tolerance
    assert scores == pytest.approx(_QWEN3_EXPECTED_SCORES, abs=1e-4)


@_requires_cuda
def test_qwen3_reranker_original_with_identity_activation():
    """Test original Qwen3 Reranker with Identity activation function and a default prompt."""
    model = CrossEncoder(
        "Qwen/Qwen3-Reranker-0.6B",
        prompts={"web_search": _QWEN3_TASK},
        default_prompt_name="web_search",
        activation_fn=torch.nn.Identity(),
        model_kwargs={"torch_dtype": torch.float32},
        processor_kwargs={"chat_template": _QWEN3_CHAT_TEMPLATE},
    )
    assert model.dtype == torch.float32

    pairs = [[_QWEN3_QUERY, doc] for doc in _QWEN3_DOCUMENTS]
    scores = model.predict(pairs)

    # Assert scores match expected values with tolerance
    assert scores == pytest.approx(_QWEN3_EXPECTED_SCORES, abs=1e-4)


@_requires_cuda
def test_qwen3_reranker_original_without_prompt():
    """Test original Qwen3 Reranker without a prompt, relying on the chat template's default instruction."""
    model = CrossEncoder(
        "Qwen/Qwen3-Reranker-0.6B",
        activation_fn=torch.nn.Identity(),
        model_kwargs={"torch_dtype": torch.float32},
        processor_kwargs={"chat_template": _QWEN3_CHAT_TEMPLATE},
    )
    assert model.dtype == torch.float32

    pairs = [[_QWEN3_QUERY, doc] for doc in _QWEN3_DOCUMENTS]
    scores = model.predict(pairs)

    # Assert scores match expected values with tolerance
    assert scores == pytest.approx(_QWEN3_EXPECTED_SCORES, abs=1e-4)