Spaces:
Sleeping
Sleeping
- src/retrieval/reranker.py +7 -10
src/retrieval/reranker.py
CHANGED
|
@@ -10,16 +10,19 @@ class HybridReranker:
|
|
| 10 |
self,
|
| 11 |
vector_store: FAISS,
|
| 12 |
reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
| 13 |
-
device: str = 'cpu'
|
|
|
|
| 14 |
):
|
| 15 |
self.vector_store = vector_store
|
| 16 |
-
self.reranker = CrossEncoder(reranker_model, max_length=512, device=device)
|
| 17 |
-
docs_in_order = list(self.vector_store.docstore._dict.values())
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
self.chunk_texts = [doc.page_content for doc in docs_in_order]
|
| 20 |
self.chunk_metadata = [doc.metadata for doc in docs_in_order]
|
| 21 |
|
| 22 |
-
print(
|
| 23 |
self.vectorizer = TfidfVectorizer()
|
| 24 |
self.tfidf_matrix = self.vectorizer.fit_transform(self.chunk_texts)
|
| 25 |
print("reranker ready")
|
|
@@ -31,28 +34,22 @@ class HybridReranker:
|
|
| 31 |
top_k_final: int = 5,
|
| 32 |
) -> List[Document]:
|
| 33 |
dense_docs = self.vector_store.similarity_search(query, k=top_k_dense)
|
| 34 |
-
|
| 35 |
q_vec = self.vectorizer.transform([query])
|
| 36 |
sparse_scores = (self.tfidf_matrix @ q_vec.T).toarray().ravel()
|
| 37 |
sparse_indices = np.argsort(-sparse_scores)[:top_k_dense]
|
| 38 |
-
|
| 39 |
sparse_docs = [
|
| 40 |
Document(page_content=self.chunk_texts[i], metadata=self.chunk_metadata[i])
|
| 41 |
for i in sparse_indices
|
| 42 |
]
|
| 43 |
-
|
| 44 |
combined_docs = []
|
| 45 |
seen_contents = set()
|
| 46 |
for doc in dense_docs + sparse_docs:
|
| 47 |
if doc.page_content not in seen_contents:
|
| 48 |
combined_docs.append(doc)
|
| 49 |
seen_contents.add(doc.page_content)
|
| 50 |
-
|
| 51 |
pairs = [[query, doc.page_content] for doc in combined_docs]
|
| 52 |
rerank_scores = self.reranker.predict(pairs)
|
| 53 |
-
|
| 54 |
doc_scores = list(zip(combined_docs, rerank_scores))
|
| 55 |
sorted_doc_scores = sorted(doc_scores, key=lambda x: x[1], reverse=True)
|
| 56 |
final_docs = [doc for doc, score in sorted_doc_scores[:top_k_final]]
|
| 57 |
-
|
| 58 |
return final_docs
|
|
|
|
| 10 |
self,
|
| 11 |
vector_store: FAISS,
|
| 12 |
reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
| 13 |
+
device: str = 'cpu',
|
| 14 |
+
cache_dir: str = "/app/huggingface_cache"
|
| 15 |
):
|
| 16 |
self.vector_store = vector_store
|
|
|
|
|
|
|
| 17 |
|
| 18 |
+
print(f"loading CrossEncoder. saving in: {cache_dir}")
|
| 19 |
+
self.reranker = CrossEncoder(reranker_model, max_length=512, device=device, cache_folder=cache_dir)
|
| 20 |
+
|
| 21 |
+
docs_in_order = list(self.vector_store.docstore._dict.values())
|
| 22 |
self.chunk_texts = [doc.page_content for doc in docs_in_order]
|
| 23 |
self.chunk_metadata = [doc.metadata for doc in docs_in_order]
|
| 24 |
|
| 25 |
+
print("building tf-idf")
|
| 26 |
self.vectorizer = TfidfVectorizer()
|
| 27 |
self.tfidf_matrix = self.vectorizer.fit_transform(self.chunk_texts)
|
| 28 |
print("reranker ready")
|
|
|
|
| 34 |
top_k_final: int = 5,
|
| 35 |
) -> List[Document]:
|
| 36 |
dense_docs = self.vector_store.similarity_search(query, k=top_k_dense)
|
|
|
|
| 37 |
q_vec = self.vectorizer.transform([query])
|
| 38 |
sparse_scores = (self.tfidf_matrix @ q_vec.T).toarray().ravel()
|
| 39 |
sparse_indices = np.argsort(-sparse_scores)[:top_k_dense]
|
|
|
|
| 40 |
sparse_docs = [
|
| 41 |
Document(page_content=self.chunk_texts[i], metadata=self.chunk_metadata[i])
|
| 42 |
for i in sparse_indices
|
| 43 |
]
|
|
|
|
| 44 |
combined_docs = []
|
| 45 |
seen_contents = set()
|
| 46 |
for doc in dense_docs + sparse_docs:
|
| 47 |
if doc.page_content not in seen_contents:
|
| 48 |
combined_docs.append(doc)
|
| 49 |
seen_contents.add(doc.page_content)
|
|
|
|
| 50 |
pairs = [[query, doc.page_content] for doc in combined_docs]
|
| 51 |
rerank_scores = self.reranker.predict(pairs)
|
|
|
|
| 52 |
doc_scores = list(zip(combined_docs, rerank_scores))
|
| 53 |
sorted_doc_scores = sorted(doc_scores, key=lambda x: x[1], reverse=True)
|
| 54 |
final_docs = [doc for doc, score in sorted_doc_scores[:top_k_final]]
|
|
|
|
| 55 |
return final_docs
|