msh2481
commited on
Commit
·
ef93aa3
1
Parent(s):
82b3a71
Step 3
Browse files- backend/semantic_search.py +15 -7
backend/semantic_search.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
-
import lancedb
|
| 2 |
import os
|
| 3 |
-
import gradio as gr
|
| 4 |
-
from sentence_transformers import SentenceTransformer
|
| 5 |
|
| 6 |
|
| 7 |
db = lancedb.connect(".lancedb")
|
|
@@ -10,17 +10,25 @@ TABLE = db.open_table(os.getenv("TABLE_NAME"))
|
|
| 10 |
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
|
| 11 |
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
|
| 12 |
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
|
|
|
|
| 13 |
|
| 14 |
retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
-
def retrieve(query, k):
|
| 18 |
query_vec = retriever.encode(query)
|
| 19 |
try:
|
| 20 |
-
documents =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
documents = [doc[TEXT_COLUMN] for doc in documents]
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
| 24 |
|
| 25 |
except Exception as e:
|
| 26 |
raise gr.Error(str(e))
|
|
|
|
| 1 |
+
import lancedb # type: ignore
|
| 2 |
import os
|
| 3 |
+
import gradio as gr # type: ignore
|
| 4 |
+
from sentence_transformers import SentenceTransformer, CrossEncoder # type: ignore
|
| 5 |
|
| 6 |
|
| 7 |
db = lancedb.connect(".lancedb")
|
|
|
|
| 10 |
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
|
| 11 |
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
|
| 12 |
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
|
| 13 |
+
RERANKER = os.getenv("RERANKER", "cross-encoder/ms-marco-MiniLM-L-6-v2")
|
| 14 |
|
| 15 |
retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
|
| 16 |
+
reranker = CrossEncoder(RERANKER)
|
| 17 |
|
| 18 |
|
| 19 |
+
def retrieve(query, k, rerank_factor=3):
|
| 20 |
query_vec = retriever.encode(query)
|
| 21 |
try:
|
| 22 |
+
documents = (
|
| 23 |
+
TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN)
|
| 24 |
+
.limit(k * rerank_factor)
|
| 25 |
+
.to_list()
|
| 26 |
+
)
|
| 27 |
documents = [doc[TEXT_COLUMN] for doc in documents]
|
| 28 |
+
scores = reranker.predict([(query, doc) for doc in documents])
|
| 29 |
+
best_scores_and_documents = sorted(zip(scores, documents), reverse=True)[:k]
|
| 30 |
+
best_documents = [doc[1] for doc in best_scores_and_documents]
|
| 31 |
+
return best_documents
|
| 32 |
|
| 33 |
except Exception as e:
|
| 34 |
raise gr.Error(str(e))
|