davidtran999 commited on
Commit
a77e733
·
verified ·
1 Parent(s): c71e75d

Upload backend/venv/lib/python3.10/site-packages/sentence_transformers/similarity_functions.py with huggingface_hub

Browse files
backend/venv/lib/python3.10/site-packages/sentence_transformers/similarity_functions.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from enum import Enum
4
+ from typing import Callable
5
+
6
+ from numpy import ndarray
7
+ from torch import Tensor
8
+
9
+ from .util import (
10
+ cos_sim,
11
+ dot_score,
12
+ euclidean_sim,
13
+ manhattan_sim,
14
+ pairwise_cos_sim,
15
+ pairwise_dot_score,
16
+ pairwise_euclidean_sim,
17
+ pairwise_manhattan_sim,
18
+ )
19
+
20
+
21
+ class SimilarityFunction(Enum):
22
+ """
23
+ Enum class for supported similarity functions. The following functions are supported:
24
+
25
+ - ``SimilarityFunction.COSINE`` (``"cosine"``): Cosine similarity
26
+ - ``SimilarityFunction.DOT_PRODUCT`` (``"dot"``, ``dot_product``): Dot product similarity
27
+ - ``SimilarityFunction.EUCLIDEAN`` (``"euclidean"``): Euclidean distance
28
+ - ``SimilarityFunction.MANHATTAN`` (``"manhattan"``): Manhattan distance
29
+ """
30
+
31
+ COSINE = "cosine"
32
+ DOT_PRODUCT = "dot"
33
+ DOT = "dot" # Alias for DOT_PRODUCT
34
+ EUCLIDEAN = "euclidean"
35
+ MANHATTAN = "manhattan"
36
+
37
+ @staticmethod
38
+ def to_similarity_fn(
39
+ similarity_function: str | SimilarityFunction,
40
+ ) -> Callable[[Tensor | ndarray, Tensor | ndarray], Tensor]:
41
+ """
42
+ Converts a similarity function name or enum value to the corresponding similarity function.
43
+
44
+ Args:
45
+ similarity_function (Union[str, SimilarityFunction]): The name or enum value of the similarity function.
46
+
47
+ Returns:
48
+ Callable[[Union[Tensor, ndarray], Union[Tensor, ndarray]], Tensor]: The corresponding similarity function.
49
+
50
+ Raises:
51
+ ValueError: If the provided function is not supported.
52
+
53
+ Example:
54
+ >>> similarity_fn = SimilarityFunction.to_similarity_fn("cosine")
55
+ >>> similarity_scores = similarity_fn(embeddings1, embeddings2)
56
+ >>> similarity_scores
57
+ tensor([[0.3952, 0.0554],
58
+ [0.0992, 0.1570]])
59
+ """
60
+ similarity_function = SimilarityFunction(similarity_function)
61
+
62
+ if similarity_function == SimilarityFunction.COSINE:
63
+ return cos_sim
64
+ if similarity_function == SimilarityFunction.DOT_PRODUCT:
65
+ return dot_score
66
+ if similarity_function == SimilarityFunction.MANHATTAN:
67
+ return manhattan_sim
68
+ if similarity_function == SimilarityFunction.EUCLIDEAN:
69
+ return euclidean_sim
70
+
71
+ raise ValueError(
72
+ f"The provided function {similarity_function} is not supported. Use one of the supported values: {SimilarityFunction.possible_values()}."
73
+ )
74
+
75
+ @staticmethod
76
+ def to_similarity_pairwise_fn(
77
+ similarity_function: str | SimilarityFunction,
78
+ ) -> Callable[[Tensor | ndarray, Tensor | ndarray], Tensor]:
79
+ """
80
+ Converts a similarity function into a pairwise similarity function.
81
+
82
+ The pairwise similarity function returns the diagonal vector from the similarity matrix, i.e. it only
83
+ computes the similarity(a[i], b[i]) for each i in the range of the input tensors, rather than
84
+ computing the similarity between all pairs of a and b.
85
+
86
+ Args:
87
+ similarity_function (Union[str, SimilarityFunction]): The name or enum value of the similarity function.
88
+
89
+ Returns:
90
+ Callable[[Union[Tensor, ndarray], Union[Tensor, ndarray]], Tensor]: The pairwise similarity function.
91
+
92
+ Raises:
93
+ ValueError: If the provided similarity function is not supported.
94
+
95
+ Example:
96
+ >>> pairwise_fn = SimilarityFunction.to_similarity_pairwise_fn("cosine")
97
+ >>> similarity_scores = pairwise_fn(embeddings1, embeddings2)
98
+ >>> similarity_scores
99
+ tensor([0.3952, 0.1570])
100
+ """
101
+ similarity_function = SimilarityFunction(similarity_function)
102
+
103
+ if similarity_function == SimilarityFunction.COSINE:
104
+ return pairwise_cos_sim
105
+ if similarity_function == SimilarityFunction.DOT_PRODUCT:
106
+ return pairwise_dot_score
107
+ if similarity_function == SimilarityFunction.MANHATTAN:
108
+ return pairwise_manhattan_sim
109
+ if similarity_function == SimilarityFunction.EUCLIDEAN:
110
+ return pairwise_euclidean_sim
111
+
112
+ raise ValueError(
113
+ f"The provided function {similarity_function} is not supported. Use one of the supported values: {SimilarityFunction.possible_values()}."
114
+ )
115
+
116
+ @staticmethod
117
+ def possible_values() -> list[str]:
118
+ """
119
+ Returns a list of possible values for the SimilarityFunction enum.
120
+
121
+ Returns:
122
+ list: A list of possible values for the SimilarityFunction enum.
123
+
124
+ Example:
125
+ >>> possible_values = SimilarityFunction.possible_values()
126
+ >>> possible_values
127
+ ['cosine', 'dot', 'euclidean', 'manhattan']
128
+ """
129
+ return [m.value for m in SimilarityFunction]