Spaces:
Running
Running
aarya-16 commited on
Commit ·
201ecae
1
Parent(s): 44a3094
Implemented Reciprocal Rank Fusion (#87)
Browse files* Implemented Ensemble Retriever
* Used self.alpha for the weights
* Cleaned up the return statement
- sage/vector_store.py +17 -12
sage/vector_store.py
CHANGED
|
@@ -8,7 +8,8 @@ from typing import Dict, Generator, List, Optional, Tuple
|
|
| 8 |
|
| 9 |
import marqo
|
| 10 |
import nltk
|
| 11 |
-
from langchain_community.retrievers import
|
|
|
|
| 12 |
from langchain_community.vectorstores import Marqo
|
| 13 |
from langchain_community.vectorstores import Pinecone as LangChainPinecone
|
| 14 |
from langchain_core.documents import Document
|
|
@@ -133,19 +134,23 @@ class PineconeVectorStore(VectorStore):
|
|
| 133 |
self.index.upsert(vectors=pinecone_vectors, namespace=namespace)
|
| 134 |
|
| 135 |
def as_retriever(self, top_k: int, embeddings: Embeddings, namespace: str):
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
return LangChainPinecone.from_existing_index(
|
| 147 |
index_name=self.index_name, embedding=embeddings, namespace=namespace
|
| 148 |
).as_retriever(search_kwargs={"k": top_k})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
|
| 151 |
class MarqoVectorStore(VectorStore):
|
|
|
|
| 8 |
|
| 9 |
import marqo
|
| 10 |
import nltk
|
| 11 |
+
from langchain_community.retrievers import BM25Retriever
|
| 12 |
+
from langchain.retrievers import EnsembleRetriever
|
| 13 |
from langchain_community.vectorstores import Marqo
|
| 14 |
from langchain_community.vectorstores import Pinecone as LangChainPinecone
|
| 15 |
from langchain_core.documents import Document
|
|
|
|
| 134 |
self.index.upsert(vectors=pinecone_vectors, namespace=namespace)
|
| 135 |
|
| 136 |
def as_retriever(self, top_k: int, embeddings: Embeddings, namespace: str):
|
| 137 |
+
bm25_retriever = BM25Retriever(
|
| 138 |
+
embeddings=embeddings,
|
| 139 |
+
sparse_encoder=self.bm25_encoder,
|
| 140 |
+
index=self.index,
|
| 141 |
+
namespace=namespace,
|
| 142 |
+
top_k=top_k,
|
| 143 |
+
) if self.bm25_encoder else None
|
| 144 |
+
|
| 145 |
+
dense_retriever = LangChainPinecone.from_existing_index(
|
|
|
|
|
|
|
| 146 |
index_name=self.index_name, embedding=embeddings, namespace=namespace
|
| 147 |
).as_retriever(search_kwargs={"k": top_k})
|
| 148 |
+
|
| 149 |
+
if bm25_retriever:
|
| 150 |
+
return EnsembleRetriever(retrievers=[dense_retriever, bm25_retriever], weights=[self.alpha, 1-self.alpha])
|
| 151 |
+
else:
|
| 152 |
+
return dense_retriever
|
| 153 |
+
|
| 154 |
|
| 155 |
|
| 156 |
class MarqoVectorStore(VectorStore):
|