dzenzzz committed on
Commit
321de76
·
verified ·
1 Parent(s): 5baa37e

Create neural_searcher.py

Browse files
Files changed (1) hide show
  1. neural_searcher.py +42 -0
neural_searcher.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from qdrant_client import QdrantClient
2
+ from fastembed import SparseTextEmbedding
3
+ from qdrant_client import QdrantClient, models
4
+ from sentence_transformers import SentenceTransformer
5
+
6
+ # from config import API_KEY,HOST,DENSE_MODEL,SPARSE_MODEL,DENSE_MODEL_SHORT,SPARSE_MODEL_SHORT
7
class NeuralSearcher:
    """Hybrid (dense + sparse) search over one Qdrant collection.

    A query string is embedded twice: once with a SentenceTransformer
    model and once with a fastembed sparse model.  Each embedding is used
    as a prefetch branch of a single ``query_points`` call, and the two
    candidate lists are merged with ``models.Fusion.RRF``.
    """

    # Model identifiers.  The same strings are passed as ``using=`` in
    # ``search``, so the collection's named vectors must be registered
    # under exactly these names.
    DENSE_MODEL = "djovak/embedic-small"
    SPARSE_MODEL = "Qdrant/bm25"

    def __init__(self, collection_name, *, url="http://localhost:6333/", api_key=""):
        """Load both embedding models and create the Qdrant client.

        Args:
            collection_name: Name of the collection that ``search`` queries.
            url: Address handed to ``QdrantClient``; defaults to the local
                instance the class was originally hard-wired to.
            api_key: API key handed to ``QdrantClient``; defaults to the
                empty string, as in the original code.
        """
        self.collection_name = collection_name
        # The dense model is pinned to CPU.
        self.dense_model = SentenceTransformer(self.DENSE_MODEL, device="cpu")
        self.sparse_model = SparseTextEmbedding(self.SPARSE_MODEL)
        self.qdrant_client = QdrantClient(url, api_key=api_key)

    def search(self, text: str, *, limit: int = 9, prefetch_limit: int = 5):
        """Return the payloads of the points that best match ``text``.

        Args:
            text: Query string to embed and search for.
            limit: Maximum number of fused results requested from Qdrant.
            prefetch_limit: Maximum number of candidates requested from
                each of the two prefetch branches before fusion.

        Returns:
            A list with ``hit.payload`` for every point in the response,
            in the order Qdrant returned them.
        """
        dense_query = self.dense_model.encode(text).tolist()
        # ``query_embed`` yields embeddings lazily; a single input string
        # needs only the first item.
        sparse_query = next(self.sparse_model.query_embed(text)).as_object()

        # One candidate list per vector type; Qdrant fuses them server-side.
        prefetch = [
            models.Prefetch(
                query=dense_query,
                using=self.DENSE_MODEL,
                limit=prefetch_limit,
            ),
            models.Prefetch(
                query=sparse_query,
                using=self.SPARSE_MODEL,
                limit=prefetch_limit,
            ),
        ]
        response = self.qdrant_client.query_points(
            collection_name=self.collection_name,
            prefetch=prefetch,
            query=models.FusionQuery(fusion=models.Fusion.RRF),
            limit=limit,
        )

        return [hit.payload for hit in response.points]