juliaturc commited on
Commit
2fdd54a
·
1 Parent(s): 3f12090

Fix formatting

Browse files
Files changed (2) hide show
  1. sage/reranker.py +2 -2
  2. sage/vector_store.py +14 -11
sage/reranker.py CHANGED
@@ -28,14 +28,14 @@ def build_reranker(provider: str, model: Optional[str] = None, top_k: int = 5) -
28
  RerankerProvider.COHERE.value: "COHERE_API_KEY",
29
  RerankerProvider.NVIDIA.value: "NVIDIA_API_KEY",
30
  RerankerProvider.JINA.value: "JINA_API_KEY",
31
- RerankerProvider.VOYAGE.value: "VOYAGE_API_KEY"
32
  }
33
 
34
  provider_defaults = {
35
  RerankerProvider.HUGGINGFACE.value: "cross-encoder/ms-marco-MiniLM-L-6-v2",
36
  RerankerProvider.COHERE.value: "rerank-english-v3.0",
37
  RerankerProvider.NVIDIA.value: "nvidia/nv-rerankqa-mistral-4b-v3",
38
- RerankerProvider.VOYAGE.value: "rerank-1"
39
  }
40
 
41
  model = model or provider_defaults.get(provider)
 
28
  RerankerProvider.COHERE.value: "COHERE_API_KEY",
29
  RerankerProvider.NVIDIA.value: "NVIDIA_API_KEY",
30
  RerankerProvider.JINA.value: "JINA_API_KEY",
31
+ RerankerProvider.VOYAGE.value: "VOYAGE_API_KEY",
32
  }
33
 
34
  provider_defaults = {
35
  RerankerProvider.HUGGINGFACE.value: "cross-encoder/ms-marco-MiniLM-L-6-v2",
36
  RerankerProvider.COHERE.value: "rerank-english-v3.0",
37
  RerankerProvider.NVIDIA.value: "nvidia/nv-rerankqa-mistral-4b-v3",
38
+ RerankerProvider.VOYAGE.value: "rerank-1",
39
  }
40
 
41
  model = model or provider_defaults.get(provider)
sage/vector_store.py CHANGED
@@ -8,8 +8,8 @@ from typing import Dict, Generator, List, Optional, Tuple
8
 
9
  import marqo
10
  import nltk
11
- from langchain_community.retrievers import BM25Retriever
12
  from langchain.retrievers import EnsembleRetriever
 
13
  from langchain_community.vectorstores import Marqo
14
  from langchain_community.vectorstores import Pinecone as LangChainPinecone
15
  from langchain_core.documents import Document
@@ -134,25 +134,28 @@ class PineconeVectorStore(VectorStore):
134
  self.index.upsert(vectors=pinecone_vectors, namespace=namespace)
135
 
136
  def as_retriever(self, top_k: int, embeddings: Embeddings, namespace: str):
137
- bm25_retriever = BM25Retriever(
138
- embeddings=embeddings,
139
- sparse_encoder=self.bm25_encoder,
140
- index=self.index,
141
- namespace=namespace,
142
- top_k=top_k,
143
- ) if self.bm25_encoder else None
 
 
 
 
144
 
145
  dense_retriever = LangChainPinecone.from_existing_index(
146
  index_name=self.index_name, embedding=embeddings, namespace=namespace
147
  ).as_retriever(search_kwargs={"k": top_k})
148
-
149
  if bm25_retriever:
150
- return EnsembleRetriever(retrievers=[dense_retriever, bm25_retriever], weights=[self.alpha, 1-self.alpha])
151
  else:
152
  return dense_retriever
153
 
154
 
155
-
156
  class MarqoVectorStore(VectorStore):
157
  """Vector store implementation using Marqo."""
158
 
 
8
 
9
  import marqo
10
  import nltk
 
11
  from langchain.retrievers import EnsembleRetriever
12
+ from langchain_community.retrievers import BM25Retriever
13
  from langchain_community.vectorstores import Marqo
14
  from langchain_community.vectorstores import Pinecone as LangChainPinecone
15
  from langchain_core.documents import Document
 
134
  self.index.upsert(vectors=pinecone_vectors, namespace=namespace)
135
 
136
  def as_retriever(self, top_k: int, embeddings: Embeddings, namespace: str):
137
+ bm25_retriever = (
138
+ BM25Retriever(
139
+ embeddings=embeddings,
140
+ sparse_encoder=self.bm25_encoder,
141
+ index=self.index,
142
+ namespace=namespace,
143
+ top_k=top_k,
144
+ )
145
+ if self.bm25_encoder
146
+ else None
147
+ )
148
 
149
  dense_retriever = LangChainPinecone.from_existing_index(
150
  index_name=self.index_name, embedding=embeddings, namespace=namespace
151
  ).as_retriever(search_kwargs={"k": top_k})
152
+
153
  if bm25_retriever:
154
+ return EnsembleRetriever(retrievers=[dense_retriever, bm25_retriever], weights=[self.alpha, 1 - self.alpha])
155
  else:
156
  return dense_retriever
157
 
158
 
 
159
  class MarqoVectorStore(VectorStore):
160
  """Vector store implementation using Marqo."""
161