| from __future__ import annotations |
|
|
| import torch |
|
|
| from sentence_transformers.sparse_encoder import SparseEncoder |
| from sentence_transformers.sparse_encoder.modules import Router, SparseStaticEmbedding, SpladePooling, Transformer |
|
|
|
|
def test_opensearch_v2_distill_similarity():
    """Pin the scores produced by the OpenSearch v2-distill sparse encoder.

    The encoder is assembled as a query/document ``Router``:

    * query side: a single ``SparseStaticEmbedding`` built with ``from_json``
      from the model id, sharing the document encoder's tokenizer and created
      with ``frozen=True`` (no transformer runs on queries);
    * document side: the model's fill-mask ``Transformer`` followed by
      ``SpladePooling("max")``.

    The test checks the query/document dot-product similarity and the
    query/document weights of a few salient tokens.

    NOTE(review): the model is loaded by hub id, which presumably needs network
    access or a warm local cache.
    """
    repo_id = "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill"

    # The document encoder is created first because its tokenizer is shared
    # with the static query embedding.
    mlm = Transformer(repo_id, transformer_task="fill-mask")
    query_pipeline = [
        SparseStaticEmbedding.from_json(
            repo_id,
            tokenizer=mlm.tokenizer,
            frozen=True,
        ),
    ]
    document_pipeline = [mlm, SpladePooling("max")]

    router = Router.for_query_document(
        query_modules=query_pipeline,
        document_modules=document_pipeline,
    )
    model = SparseEncoder(modules=[router], similarity_fn_name="dot")

    query = "What's the weather in ny now?"
    document = "Currently New York is rainy."

    q_vec = model.encode_query(query)
    d_vec = model.encode_document(document)
    score = model.similarity(q_vec, d_vec).cpu()

    expected_similarity = 17.5307
    tolerance = 1e-3

    # torch.allclose accepts |a - b| <= atol + rtol * |b|; with rtol=0.01 the
    # relative term (~0.18 here) dominates, so this check is looser than the
    # absolute per-token checks further down.
    assert torch.allclose(score, torch.tensor([[expected_similarity]]), atol=tolerance, rtol=0.01), (
        f"Expected similarity ~{expected_similarity}, got {score.item():.4f}"
    )

    # Only the top three query tokens are decoded; the document is decoded in full.
    q_pairs = model.decode(q_vec, top_k=3)
    d_pairs = model.decode(d_vec)

    # Reference weights per token: (weight in query, weight in document).
    expected_tokens = {
        "ny": {"query": 5.7729, "document": 1.4109},
        "weather": {"query": 4.5684, "document": 1.4673},
        "now": {"query": 3.5895, "document": 0.7473},
    }

    q_weights = dict(q_pairs)
    d_weights = dict(d_pairs)

    for tok, ref in expected_tokens.items():
        # Presence is checked on both sides before any weight comparison.
        assert tok in q_weights, f"Token '{tok}' not found in query scores"
        assert tok in d_weights, f"Token '{tok}' not found in document scores"

        got_q, want_q = q_weights[tok], ref["query"]
        got_d, want_d = d_weights[tok], ref["document"]

        assert abs(got_q - want_q) < tolerance, (
            f"Query score for '{tok}': expected {want_q}, got {got_q}"
        )
        assert abs(got_d - want_d) < tolerance, (
            f"Document score for '{tok}': expected {want_d}, got {got_d}"
        )
|
|
|
|
def test_opensearch_v3_distill_similarity():
    """Pin the scores produced by the OpenSearch v3-distill sparse encoder.

    The layout matches the v2 test, with two differences:

    * the document-side pooling is
      ``SpladePooling(pooling_strategy="max", activation_function="log1p_relu")``;
    * a wider set of query tokens is decoded (``top_k=10``) so that six tokens
      can be checked.

    The query side is a single ``SparseStaticEmbedding`` built with
    ``from_json`` from the model id, sharing the document tokenizer and created
    with ``frozen=True``. The similarity function is the dot product.

    NOTE(review): the model is loaded by hub id, which presumably needs network
    access or a warm local cache.
    """
    repo_id = "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-distill"

    # The document encoder is created first because its tokenizer is shared
    # with the static query embedding.
    mlm = Transformer(repo_id, transformer_task="fill-mask")
    query_pipeline = [
        SparseStaticEmbedding.from_json(
            repo_id,
            tokenizer=mlm.tokenizer,
            frozen=True,
        ),
    ]
    document_pipeline = [
        mlm,
        SpladePooling(pooling_strategy="max", activation_function="log1p_relu"),
    ]

    router = Router.for_query_document(
        query_modules=query_pipeline,
        document_modules=document_pipeline,
    )
    model = SparseEncoder(modules=[router], similarity_fn_name="dot")

    query = "What's the weather in ny now?"
    document = "Currently New York is rainy."

    q_vec = model.encode_query(query)
    d_vec = model.encode_document(document)
    score = model.similarity(q_vec, d_vec).cpu()

    expected_similarity = 11.1105
    tolerance = 1e-3

    # torch.allclose accepts |a - b| <= atol + rtol * |b|; with rtol=0.01 the
    # relative term (~0.11 here) dominates, so this check is looser than the
    # absolute per-token checks further down.
    assert torch.allclose(score, torch.tensor([[expected_similarity]]), atol=tolerance, rtol=0.01), (
        f"Expected similarity ~{expected_similarity}, got {score.item():.4f}"
    )

    # Up to ten query tokens are decoded so that all six reference tokens can
    # be present; the document is decoded in full.
    q_pairs = model.decode(q_vec, top_k=10)
    d_pairs = model.decode(d_vec)

    # Reference weights per token: (weight in query, weight in document).
    expected_tokens = {
        "ny": {"query": 5.7729, "document": 0.8049},
        "weather": {"query": 4.5684, "document": 0.9710},
        "now": {"query": 3.5895, "document": 0.4720},
        "?": {"query": 3.3313, "document": 0.0286},
        "what": {"query": 2.7699, "document": 0.0787},
        "in": {"query": 0.4989, "document": 0.0417},
    }

    q_weights = dict(q_pairs)
    d_weights = dict(d_pairs)

    for tok, ref in expected_tokens.items():
        # Presence is checked on both sides before any weight comparison.
        assert tok in q_weights, f"Token '{tok}' not found in query scores"
        assert tok in d_weights, f"Token '{tok}' not found in document scores"

        got_q, want_q = q_weights[tok], ref["query"]
        got_d, want_d = d_weights[tok], ref["document"]

        assert abs(got_q - want_q) < tolerance, (
            f"Query score for '{tok}': expected {want_q}, got {got_q}"
        )
        assert abs(got_d - want_d) < tolerance, (
            f"Document score for '{tok}': expected {want_d}, got {got_d}"
        )
|
|