davidtran999 committed on
Commit
c6818a9
·
verified ·
1 Parent(s): 81b824e

Upload backend/venv/lib/python3.10/site-packages/sentence_transformers/quantization.py with huggingface_hub

Browse files
backend/venv/lib/python3.10/site-packages/sentence_transformers/quantization.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import time
5
+ from typing import TYPE_CHECKING, Literal
6
+
7
+ import numpy as np
8
+ from torch import Tensor
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ if TYPE_CHECKING:
14
+ import faiss
15
+ import usearch
16
+
17
+
18
def semantic_search_faiss(
    query_embeddings: np.ndarray,
    corpus_embeddings: np.ndarray | None = None,
    corpus_index: faiss.Index | None = None,
    corpus_precision: Literal["float32", "uint8", "ubinary"] = "float32",
    top_k: int = 10,
    ranges: np.ndarray | None = None,
    calibration_embeddings: np.ndarray | None = None,
    rescore: bool = True,
    rescore_multiplier: int = 2,
    exact: bool = True,
    output_index: bool = False,
) -> tuple[list[list[dict[str, int | float]]], float] | tuple[list[list[dict[str, int | float]]], float, faiss.Index]:
    """
    Performs semantic search using the FAISS library.

    Rescoring will be performed if:
    1. `rescore` is True
    2. The query embeddings are not quantized
    3. The corpus is quantized, i.e. the corpus precision is not float32
    Only if these conditions are true, will we search for `top_k * rescore_multiplier` samples and then rescore to only
    keep `top_k`.

    Args:
        query_embeddings: Embeddings of the query sentences. Ideally not
            quantized to allow for rescoring.
        corpus_embeddings: Embeddings of the corpus sentences. Either
            `corpus_embeddings` or `corpus_index` should be used, not
            both. The embeddings can be quantized to "uint8" or
            "ubinary" for more efficient search.
        corpus_index: FAISS index for the corpus sentences. Either
            `corpus_embeddings` or `corpus_index` should be used, not
            both.
        corpus_precision: Precision of the corpus embeddings. The
            options are "float32", "uint8", or "ubinary". Default is
            "float32".
        top_k: Number of top results to retrieve. Default is 10.
        ranges: Ranges for quantization of embeddings. This is only used
            for uint8 quantization, where the ranges refers to the
            minimum and maximum values for each dimension. So, it's a 2D
            array with shape (2, embedding_dim). Default is None, which
            means that the ranges will be calculated from the
            calibration embeddings.
        calibration_embeddings: Embeddings used for calibration during
            quantization. This is only used for uint8 quantization,
            where the calibration embeddings can be used to compute
            ranges, i.e. the minimum and maximum values for each
            dimension. Default is None, which means that the ranges will
            be calculated from the query embeddings. This is not
            recommended.
        rescore: Whether to perform rescoring. Note that rescoring still
            will only be used if the query embeddings are not quantized
            and the corpus is quantized, i.e. the corpus precision is
            not "float32". Default is True.
        rescore_multiplier: Oversampling factor for rescoring. The code
            will now search `top_k * rescore_multiplier` samples and
            then rescore to only keep `top_k`. Default is 2.
        exact: Whether to use exact search or approximate search.
            Only consulted when a new index is built from
            `corpus_embeddings`. Default is True.
        output_index: Whether to output the FAISS index used for the
            search. Default is False.

    Returns:
        A tuple containing a list of search results and the time taken
        for the search. If `output_index` is True, the tuple will also
        contain the FAISS index used for the search.

    Raises:
        ValueError: If both `corpus_embeddings` and `corpus_index` are
            provided, if neither is provided, or if an index has to be
            built and `corpus_precision` is not one of "float32",
            "uint8" or "ubinary".

    The list of search results is in the format: [[{"corpus_id": int, "score": float}, ...], ...]
    The time taken for the search is a float value.
    """
    # Validate the arguments before importing the optional dependency, so that
    # misuse is always reported as a ValueError.
    if corpus_embeddings is not None and corpus_index is not None:
        raise ValueError("Only corpus_embeddings or corpus_index should be used, not both.")
    if corpus_embeddings is None and corpus_index is None:
        raise ValueError("Either corpus_embeddings or corpus_index should be used.")
    # Without this check an unsupported precision would leave `corpus_index` as None
    # and fail later with an opaque AttributeError on `.add`.
    if corpus_index is None and corpus_precision not in ("float32", "uint8", "ubinary"):
        raise ValueError('corpus_precision must be "float32", "uint8", or "ubinary" for faiss')

    import faiss

    # If corpus_index is not provided, create a new index
    if corpus_index is None:
        if corpus_precision == "ubinary":
            # Binary indices are sized in bits, while the packed embeddings are sized in bytes
            num_bits = corpus_embeddings.shape[1] * 8
            corpus_index = faiss.IndexBinaryFlat(num_bits) if exact else faiss.IndexBinaryHNSW(num_bits, 16)
        else:
            # "float32" and "uint8" corpora share the same inner-product index types
            num_dims = corpus_embeddings.shape[1]
            corpus_index = faiss.IndexFlatIP(num_dims) if exact else faiss.IndexHNSWFlat(num_dims, 16)

        corpus_index.add(corpus_embeddings)

    # If rescoring is enabled and the query embeddings are in float32, we need to quantize them
    # to the same precision as the corpus embeddings. Also update the top_k value to account for the
    # rescore_multiplier
    rescore_embeddings = None
    k = top_k
    if query_embeddings.dtype not in (np.uint8, np.int8):
        if rescore:
            if corpus_precision != "float32":
                # Keep the unquantized queries around: they are used to rescore the oversampled hits
                rescore_embeddings = query_embeddings
                k *= rescore_multiplier
            else:
                logger.warning(
                    "Rescoring is enabled but the corpus is not quantized. Either pass `rescore=False` or "
                    'quantize the corpus embeddings with `quantize_embeddings(embeddings, precision="...") `'
                    'and pass `corpus_precision="..."` to `semantic_search_faiss`.'
                )

        query_embeddings = quantize_embeddings(
            query_embeddings,
            precision=corpus_precision,
            ranges=ranges,
            calibration_embeddings=calibration_embeddings,
        )
    elif rescore:
        logger.warning(
            "Rescoring is enabled but the query embeddings are quantized. Either pass `rescore=False` or don't quantize the query embeddings."
        )

    # Perform the search using the FAISS index
    start_t = time.time()
    scores, indices = corpus_index.search(query_embeddings, k)

    # If rescoring is enabled, we need to rescore the results using the rescore_embeddings
    if rescore_embeddings is not None:
        top_k_embeddings = np.array(
            [[corpus_index.reconstruct(idx.item()) for idx in query_indices] for query_indices in indices]
        )
        # If the corpus precision is binary, we need to unpack the bits
        if corpus_precision == "ubinary":
            top_k_embeddings = np.unpackbits(top_k_embeddings, axis=-1).astype(int)
        else:
            top_k_embeddings = top_k_embeddings.astype(int)

        # rescore_embeddings: [num_queries, embedding_dim]
        # top_k_embeddings: [num_queries, top_k, embedding_dim]
        # updated_scores: [num_queries, top_k]
        # We use einsum to calculate the dot product between the query and the top_k embeddings, equivalent to looping
        # over the queries and calculating 'rescore_embeddings[i] @ top_k_embeddings[i].T'
        rescored_scores = np.einsum("ij,ikj->ik", rescore_embeddings, top_k_embeddings)
        rescored_indices = np.argsort(-rescored_scores)[:, :top_k]
        query_rows = np.arange(len(query_embeddings))[:, None]
        indices = indices[query_rows, rescored_indices]
        scores = rescored_scores[query_rows, rescored_indices]

    # The reported time covers the index search and, when it runs, the rescoring
    delta_t = time.time() - start_t

    outputs = (
        [
            [
                {"corpus_id": int(neighbor), "score": float(score)}
                for score, neighbor in zip(scores[query_id], indices[query_id])
            ]
            for query_id in range(len(query_embeddings))
        ],
        delta_t,
    )
    if output_index:
        outputs = (*outputs, corpus_index)
    return outputs
183
+
184
+
185
def semantic_search_usearch(
    query_embeddings: np.ndarray,
    corpus_embeddings: np.ndarray | None = None,
    corpus_index: usearch.index.Index | None = None,
    corpus_precision: Literal["float32", "int8", "binary"] = "float32",
    top_k: int = 10,
    ranges: np.ndarray | None = None,
    calibration_embeddings: np.ndarray | None = None,
    rescore: bool = True,
    rescore_multiplier: int = 2,
    exact: bool = True,
    output_index: bool = False,
) -> tuple[list[list[dict[str, int | float]]], float, usearch.index.Index]:
    """
    Run a semantic search over a corpus with the usearch library.

    The search oversamples (`top_k * rescore_multiplier` candidates) and
    rescores down to `top_k` only when all of the following hold:
    1. `rescore` is True
    2. The query embeddings are not quantized (dtype is neither uint8 nor int8)
    3. The index does not store float32 vectors, i.e. the corpus is quantized

    Args:
        query_embeddings: Query embeddings. Pass them unquantized if
            rescoring is desired; they are quantized internally to
            `corpus_precision` before searching.
        corpus_embeddings: Corpus embeddings used to build a fresh
            index. Mutually exclusive with `corpus_index`.
        corpus_index: An existing usearch index to search. Mutually
            exclusive with `corpus_embeddings`.
        corpus_precision: One of "float32", "int8" or "binary".
            Default is "float32".
        top_k: Number of results to return per query. Default is 10.
        ranges: Per-dimension minimum/maximum values, shape
            (2, embedding_dim), forwarded to `quantize_embeddings` for
            int8 quantization of the queries. Default is None.
        calibration_embeddings: Embeddings forwarded to
            `quantize_embeddings` to derive `ranges` when those are not
            given. Default is None.
        rescore: Whether rescoring is requested. Default is True.
        rescore_multiplier: Oversampling factor applied to `top_k` when
            rescoring. Default is 2.
        exact: Forwarded to the index search to choose exact or
            approximate search. Default is True.
        output_index: Whether to append the usearch index to the
            returned tuple. Default is False.

    Returns:
        `(results, seconds)` or, when `output_index` is True,
        `(results, seconds, index)`. `results` has the format
        [[{"corpus_id": int, "score": float}, ...], ...] and `seconds`
        is the time spent searching and rescoring.

    Raises:
        ValueError: If both or neither of `corpus_embeddings` and
            `corpus_index` are given, or if `corpus_precision` is not
            one of "float32", "int8" or "binary".
    """
    from usearch.compiled import ScalarKind
    from usearch.index import Index

    embeddings_given = corpus_embeddings is not None
    index_given = corpus_index is not None
    if embeddings_given and index_given:
        raise ValueError("Only corpus_embeddings or corpus_index should be used, not both.")
    if not (embeddings_given or index_given):
        raise ValueError("Either corpus_embeddings or corpus_index should be used.")
    if corpus_precision not in ["float32", "int8", "binary"]:
        raise ValueError('corpus_precision must be "float32", "int8", or "binary" for usearch')

    # Build an index from the corpus embeddings when none was handed in
    if not index_given:
        metric_name, scalar_dtype = {
            "float32": ("cos", "f32"),
            "int8": ("ip", "i8"),
            "binary": ("hamming", "b1"),
        }[corpus_precision]
        corpus_index = Index(
            ndim=corpus_embeddings.shape[1],
            metric=metric_name,
            dtype=scalar_dtype,
        )
        corpus_index.add(np.arange(len(corpus_embeddings)), corpus_embeddings)

    # Unquantized queries are quantized to the corpus precision; the originals are kept
    # for rescoring when the index itself is quantized, and the search is oversampled.
    rescore_embeddings = None
    k = top_k
    if query_embeddings.dtype in (np.uint8, np.int8):
        if rescore:
            logger.warning(
                "Rescoring is enabled but the query embeddings are quantized. Either pass `rescore=False` or don't quantize the query embeddings."
            )
    else:
        if rescore:
            if corpus_index.dtype != ScalarKind.F32:
                rescore_embeddings = query_embeddings
                k *= rescore_multiplier
            else:
                logger.warning(
                    "Rescoring is enabled but the corpus is not quantized. Either pass `rescore=False` or "
                    'quantize the corpus embeddings with `quantize_embeddings(embeddings, precision="...") `'
                    'and pass `corpus_precision="..."` to `semantic_search_usearch`.'
                )

        query_embeddings = quantize_embeddings(
            query_embeddings,
            precision=corpus_precision,
            ranges=ranges,
            calibration_embeddings=calibration_embeddings,
        )

    # Query the usearch index; timing starts here and also covers rescoring
    start_t = time.time()
    matches = corpus_index.search(query_embeddings, count=k, exact=exact)
    scores = matches.distances
    indices = matches.keys

    # Lift lower-dimensional results to 2D so the per-query indexing below works
    if scores.ndim < 2:
        scores = np.atleast_2d(scores)
    if indices.ndim < 2:
        indices = np.atleast_2d(indices)

    if rescore_embeddings is not None:
        # Fetch the stored vectors of every candidate, one batch per query
        candidates = np.array([corpus_index.get(candidate_keys) for candidate_keys in indices])
        if corpus_precision == "binary":
            # Packed bits have to be expanded before taking dot products
            candidates = np.unpackbits(candidates.astype(np.uint8), axis=-1)
        candidates = candidates.astype(int)

        # rescore_embeddings: [num_queries, embedding_dim]
        # candidates:         [num_queries, k, embedding_dim]
        # rescored_scores:    [num_queries, k]
        # The einsum is the batched form of 'rescore_embeddings[i] @ candidates[i].T'
        rescored_scores = np.einsum("ij,ikj->ik", rescore_embeddings, candidates)
        best = np.argsort(-rescored_scores)[:, :top_k]
        rows = np.arange(len(query_embeddings))[:, None]
        indices = indices[rows, best]
        scores = rescored_scores[rows, best]

    delta_t = time.time() - start_t

    results = []
    for query_id in range(len(query_embeddings)):
        hits = [
            {"corpus_id": int(neighbor), "score": float(score)}
            for score, neighbor in zip(scores[query_id], indices[query_id])
        ]
        results.append(hits)

    if output_index:
        return results, delta_t, corpus_index
    return results, delta_t
363
+
364
+
365
def quantize_embeddings(
    embeddings: Tensor | np.ndarray,
    precision: Literal["float32", "int8", "uint8", "binary", "ubinary"],
    ranges: np.ndarray | None = None,
    calibration_embeddings: np.ndarray | None = None,
) -> np.ndarray:
    """
    Quantizes embeddings to a lower precision. This can be used to reduce the memory footprint and increase the
    speed of similarity search. The supported precisions are "float32", "int8", "uint8", "binary", and "ubinary".

    Args:
        embeddings: Unquantized (e.g. float) embeddings to quantize to
            a given precision. A torch Tensor, a numpy array, or a list
            of either is accepted.
        precision: The precision to convert to. Options are "float32",
            "int8", "uint8", "binary", "ubinary".
        ranges (Optional[np.ndarray]): Ranges for quantization of
            embeddings. This is only used for int8/uint8 quantization,
            where the ranges refers to the minimum and maximum values
            for each dimension. So, it's a 2D array with shape (2,
            embedding_dim). Default is None, which means that the ranges
            will be calculated from the calibration embeddings.
        calibration_embeddings (Optional[np.ndarray]): Embeddings used
            for calibration during quantization. This is only used for
            int8/uint8 quantization, where the calibration embeddings
            can be used to compute ranges, i.e. the minimum and maximum
            values for each dimension. Default is None, which means that
            the ranges will be calculated from the given embeddings.
            This is not recommended.

    Returns:
        Quantized embeddings with the specified precision. For
        int8/uint8, values outside `ranges` saturate at the lowest or
        highest bucket. For binary/ubinary, every group of 8 dimensions
        is packed into one byte; the last axis is zero-padded when its
        size is not a multiple of 8.

    Raises:
        ValueError: If the embeddings are already int8/uint8, or if
            `precision` is not supported.
    """
    if isinstance(embeddings, Tensor):
        # detach() so that tensors which require grad can be converted as well
        embeddings = embeddings.detach().cpu().numpy()
    elif isinstance(embeddings, list):
        # Guard against an empty list before peeking at the first element
        if embeddings and isinstance(embeddings[0], Tensor):
            embeddings = [embedding.detach().cpu().numpy() for embedding in embeddings]
        embeddings = np.array(embeddings)
    if embeddings.dtype in (np.uint8, np.int8):
        raise ValueError("Embeddings to quantize must be float rather than int8 or uint8.")

    if precision == "float32":
        return embeddings.astype(np.float32)

    if precision in ("int8", "uint8"):
        # Either use the 1. provided ranges, 2. the calibration dataset or 3. the provided embeddings
        if ranges is None:
            if calibration_embeddings is not None:
                ranges = np.vstack((np.min(calibration_embeddings, axis=0), np.max(calibration_embeddings, axis=0)))
            else:
                if embeddings.shape[0] < 100:
                    logger.warning(
                        f"Computing {precision} quantization buckets based on {len(embeddings)} embedding{'s' if len(embeddings) != 1 else ''}."
                        f" {precision} quantization is more stable with `ranges` calculated from more embeddings "
                        "or a `calibration_embeddings` that can be used to calculate the buckets."
                    )
                ranges = np.vstack((np.min(embeddings, axis=0), np.max(embeddings, axis=0)))

        starts = ranges[0, :]
        steps = (ranges[1, :] - ranges[0, :]) / 255
        # A dimension whose minimum equals its maximum (e.g. a single calibration embedding)
        # would otherwise divide by zero; any non-zero step maps it to the lowest bucket.
        steps = np.where(steps == 0, 1.0, steps)
        buckets = (embeddings - starts) / steps

        # Clip before casting: values outside the calibrated range must saturate rather
        # than overflow during the float -> 8-bit integer conversion.
        if precision == "uint8":
            return np.clip(buckets, 0, 255).astype(np.uint8)
        return np.clip(buckets - 128, -128, 127).astype(np.int8)

    if precision in ("binary", "ubinary"):
        # Pack along the last axis so every embedding is packed independently, also when
        # the dimension is not a multiple of 8; then flatten any middle axes per row.
        packed = np.packbits(embeddings > 0, axis=-1).reshape(embeddings.shape[0], -1)
        if precision == "ubinary":
            return packed
        # Shift the unsigned bytes [0, 255] into the signed range [-128, 127]
        return (packed.astype(np.int16) - 128).astype(np.int8)

    raise ValueError(f"Precision {precision} is not supported")