dzenzzz committed on
Commit
52b95eb
·
verified ·
1 Parent(s): 9ed9588

Create doc_searcher.py

Browse files
Files changed (1) hide show
  1. doc_searcher.py +65 -0
doc_searcher.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from qdrant_client import QdrantClient
2
+ from fastembed import SparseTextEmbedding, LateInteractionTextEmbedding
3
+ from qdrant_client import QdrantClient, models
4
+ from sentence_transformers import SentenceTransformer
5
+ from config import DENSE_MODEL, SPARSE_MODEL, LATE_INTERACTION_MODEL, QDRANT_URL, QDRANT_API_KEY,HUGGING_FACE_API_KEY
6
+
7
class DocSearcher:
    """Hybrid (dense + sparse) search over a single Qdrant collection.

    The constructor loads the embedding models and opens the Qdrant client;
    ``search`` embeds a query, prefetches candidates with both vector kinds,
    and fuses them with ``models.Fusion.RRF``.
    """

    def __init__(self, collection_name):
        """Load the embedding models and connect to Qdrant.

        Args:
            collection_name: Name of the Qdrant collection that ``search``
                queries.
        """
        self.collection_name = collection_name
        self.dense_model = SentenceTransformer(
            DENSE_MODEL,
            device="cpu",
            token=HUGGING_FACE_API_KEY,
        )
        self.sparse_model = SparseTextEmbedding(SPARSE_MODEL)
        # Instantiated here but not referenced by ``search`` in this class.
        self.late_interaction_model = LateInteractionTextEmbedding(
            LATE_INTERACTION_MODEL
        )
        self.qdrant_client = QdrantClient(
            QDRANT_URL,
            api_key=QDRANT_API_KEY,
            timeout=30,
        )

    async def search(self, text: str):
        """Return the payloads of the top 10 fused hits for ``text``.

        The query is embedded with the dense and the sparse model. Each
        embedding drives one prefetch of up to 200 candidates (quantization
        rescoring turned off); the two candidate lists are then fused and the
        payload of every resulting point is returned.

        NOTE(review): nothing in this coroutine is awaited, so the encode and
        Qdrant calls run synchronously on the caller's event loop — confirm
        this is acceptable for the hosting application.

        Args:
            text: The query string to search for.

        Returns:
            A list with the ``payload`` of each returned point, in the order
            Qdrant returned them.
        """
        dense_vector = self.dense_model.encode(text).tolist()
        sparse_embedding = next(self.sparse_model.query_embed(text))
        sparse_vector = models.SparseVector(**sparse_embedding.as_object())

        # (query vector, named vector to search) for each prefetch branch.
        branches = (
            (dense_vector, DENSE_MODEL),
            (sparse_vector, SPARSE_MODEL),
        )
        candidates = [
            models.Prefetch(
                query=vector,
                using=vector_name,
                params=models.SearchParams(
                    quantization=models.QuantizationSearchParams(
                        rescore=False,
                    ),
                ),
                limit=200,
            )
            for vector, vector_name in branches
        ]

        final_params = models.SearchParams(
            hnsw_ef=128,
            quantization=models.QuantizationSearchParams(rescore=True),
        )
        fusion = models.FusionQuery(fusion=models.Fusion.RRF)

        response = self.qdrant_client.query_points(
            collection_name=self.collection_name,
            search_params=final_params,
            prefetch=candidates,
            query=fusion,
            with_payload=True,
            limit=10,
        )

        return [point.payload for point in response.points]