aarya-16 commited on
Commit
201ecae
·
1 Parent(s): 44a3094

Implemented Reciprocal Rank Fusion (#87)

Browse files

* Implemented Ensemble Retriever

* Used self.alpha for the weights

* Cleaned up the return statement

Files changed (1) hide show
  1. sage/vector_store.py +17 -12
sage/vector_store.py CHANGED
@@ -8,7 +8,8 @@ from typing import Dict, Generator, List, Optional, Tuple
8
 
9
  import marqo
10
  import nltk
11
- from langchain_community.retrievers import PineconeHybridSearchRetriever
 
12
  from langchain_community.vectorstores import Marqo
13
  from langchain_community.vectorstores import Pinecone as LangChainPinecone
14
  from langchain_core.documents import Document
@@ -133,19 +134,23 @@ class PineconeVectorStore(VectorStore):
133
  self.index.upsert(vectors=pinecone_vectors, namespace=namespace)
134
 
135
  def as_retriever(self, top_k: int, embeddings: Embeddings, namespace: str):
136
- if self.bm25_encoder:
137
- return PineconeHybridSearchRetriever(
138
- embeddings=embeddings,
139
- sparse_encoder=self.bm25_encoder,
140
- index=self.index,
141
- namespace=namespace,
142
- top_k=top_k,
143
- alpha=self.alpha,
144
- )
145
-
146
- return LangChainPinecone.from_existing_index(
147
  index_name=self.index_name, embedding=embeddings, namespace=namespace
148
  ).as_retriever(search_kwargs={"k": top_k})
 
 
 
 
 
 
149
 
150
 
151
  class MarqoVectorStore(VectorStore):
 
8
 
9
  import marqo
10
  import nltk
11
+ from langchain_community.retrievers import BM25Retriever
12
+ from langchain.retrievers import EnsembleRetriever
13
  from langchain_community.vectorstores import Marqo
14
  from langchain_community.vectorstores import Pinecone as LangChainPinecone
15
  from langchain_core.documents import Document
 
134
  self.index.upsert(vectors=pinecone_vectors, namespace=namespace)
135
 
136
  def as_retriever(self, top_k: int, embeddings: Embeddings, namespace: str):
137
+ bm25_retriever = BM25Retriever(
138
+ embeddings=embeddings,
139
+ sparse_encoder=self.bm25_encoder,
140
+ index=self.index,
141
+ namespace=namespace,
142
+ top_k=top_k,
143
+ ) if self.bm25_encoder else None
144
+
145
+ dense_retriever = LangChainPinecone.from_existing_index(
 
 
146
  index_name=self.index_name, embedding=embeddings, namespace=namespace
147
  ).as_retriever(search_kwargs={"k": top_k})
148
+
149
+ if bm25_retriever:
150
+ return EnsembleRetriever(retrievers=[dense_retriever, bm25_retriever], weights=[self.alpha, 1-self.alpha])
151
+ else:
152
+ return dense_retriever
153
+
154
 
155
 
156
  class MarqoVectorStore(VectorStore):