Upload backend/venv/lib/python3.10/site-packages/sentence_transformers/similarity_functions.py with huggingface_hub
Browse files
backend/venv/lib/python3.10/site-packages/sentence_transformers/similarity_functions.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from enum import Enum
|
| 4 |
+
from typing import Callable
|
| 5 |
+
|
| 6 |
+
from numpy import ndarray
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
|
| 9 |
+
from .util import (
|
| 10 |
+
cos_sim,
|
| 11 |
+
dot_score,
|
| 12 |
+
euclidean_sim,
|
| 13 |
+
manhattan_sim,
|
| 14 |
+
pairwise_cos_sim,
|
| 15 |
+
pairwise_dot_score,
|
| 16 |
+
pairwise_euclidean_sim,
|
| 17 |
+
pairwise_manhattan_sim,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class SimilarityFunction(Enum):
    """
    Enum class for supported similarity functions. The following functions are supported:

    - ``SimilarityFunction.COSINE`` (``"cosine"``): Cosine similarity
    - ``SimilarityFunction.DOT_PRODUCT`` (``"dot"``): Dot product similarity.
      ``SimilarityFunction.DOT`` is an alias for the same member.
    - ``SimilarityFunction.EUCLIDEAN`` (``"euclidean"``): Euclidean distance
    - ``SimilarityFunction.MANHATTAN`` (``"manhattan"``): Manhattan distance
    """

    COSINE = "cosine"
    DOT_PRODUCT = "dot"
    DOT = "dot"  # Alias for DOT_PRODUCT
    EUCLIDEAN = "euclidean"
    MANHATTAN = "manhattan"

    @staticmethod
    def _unsupported(similarity_function: object) -> ValueError:
        """Build the ``ValueError`` that reports an unsupported similarity function."""
        return ValueError(
            f"The provided function {similarity_function} is not supported. "
            f"Use one of the supported values: {SimilarityFunction.possible_values()}."
        )

    @staticmethod
    def _resolve(similarity_function: str | SimilarityFunction) -> SimilarityFunction:
        """Convert a name or enum member to an enum member, raising a descriptive ``ValueError``."""
        try:
            return SimilarityFunction(similarity_function)
        except ValueError as exc:
            # The Enum lookup only says "'x' is not a valid SimilarityFunction";
            # re-raise with the list of supported values so callers can fix the input.
            raise SimilarityFunction._unsupported(similarity_function) from exc

    @staticmethod
    def to_similarity_fn(
        similarity_function: str | SimilarityFunction,
    ) -> Callable[[Tensor | ndarray, Tensor | ndarray], Tensor]:
        """
        Converts a similarity function name or enum value to the corresponding similarity function.

        Args:
            similarity_function (Union[str, SimilarityFunction]): The name or enum value of the similarity function.

        Returns:
            Callable[[Union[Tensor, ndarray], Union[Tensor, ndarray]], Tensor]: The corresponding similarity function.

        Raises:
            ValueError: If the provided function is not supported.

        Example:
            >>> similarity_fn = SimilarityFunction.to_similarity_fn("cosine")
            >>> similarity_scores = similarity_fn(embeddings1, embeddings2)
            >>> similarity_scores
            tensor([[0.3952, 0.0554],
                    [0.0992, 0.1570]])
        """
        # Validate first so that bad input fails before anything else is evaluated.
        member = SimilarityFunction._resolve(similarity_function)

        dispatch = {
            SimilarityFunction.COSINE: cos_sim,
            SimilarityFunction.DOT_PRODUCT: dot_score,
            SimilarityFunction.MANHATTAN: manhattan_sim,
            SimilarityFunction.EUCLIDEAN: euclidean_sim,
        }
        similarity_fn = dispatch.get(member)
        if similarity_fn is None:
            # Only reachable if a member is added to the enum without a mapping here.
            raise SimilarityFunction._unsupported(similarity_function)
        return similarity_fn

    @staticmethod
    def to_similarity_pairwise_fn(
        similarity_function: str | SimilarityFunction,
    ) -> Callable[[Tensor | ndarray, Tensor | ndarray], Tensor]:
        """
        Converts a similarity function into a pairwise similarity function.

        The pairwise similarity function returns the diagonal vector from the similarity matrix, i.e. it only
        computes the similarity(a[i], b[i]) for each i in the range of the input tensors, rather than
        computing the similarity between all pairs of a and b.

        Args:
            similarity_function (Union[str, SimilarityFunction]): The name or enum value of the similarity function.

        Returns:
            Callable[[Union[Tensor, ndarray], Union[Tensor, ndarray]], Tensor]: The pairwise similarity function.

        Raises:
            ValueError: If the provided similarity function is not supported.

        Example:
            >>> pairwise_fn = SimilarityFunction.to_similarity_pairwise_fn("cosine")
            >>> similarity_scores = pairwise_fn(embeddings1, embeddings2)
            >>> similarity_scores
            tensor([0.3952, 0.1570])
        """
        # Validate first so that bad input fails before anything else is evaluated.
        member = SimilarityFunction._resolve(similarity_function)

        dispatch = {
            SimilarityFunction.COSINE: pairwise_cos_sim,
            SimilarityFunction.DOT_PRODUCT: pairwise_dot_score,
            SimilarityFunction.MANHATTAN: pairwise_manhattan_sim,
            SimilarityFunction.EUCLIDEAN: pairwise_euclidean_sim,
        }
        pairwise_fn = dispatch.get(member)
        if pairwise_fn is None:
            # Only reachable if a member is added to the enum without a mapping here.
            raise SimilarityFunction._unsupported(similarity_function)
        return pairwise_fn

    @staticmethod
    def possible_values() -> list[str]:
        """
        Returns a list of possible values for the SimilarityFunction enum.

        Iterating an Enum skips aliases, so ``"dot"`` appears only once.

        Returns:
            list: A list of possible values for the SimilarityFunction enum.

        Example:
            >>> possible_values = SimilarityFunction.possible_values()
            >>> possible_values
            ['cosine', 'dot', 'euclidean', 'manhattan']
        """
        return [m.value for m in SimilarityFunction]
|