Spaces:
Running
Running
Fix formatting
Browse files- sage/reranker.py +2 -2
- sage/vector_store.py +14 -11
sage/reranker.py
CHANGED
|
@@ -28,14 +28,14 @@ def build_reranker(provider: str, model: Optional[str] = None, top_k: int = 5) -
|
|
| 28 |
RerankerProvider.COHERE.value: "COHERE_API_KEY",
|
| 29 |
RerankerProvider.NVIDIA.value: "NVIDIA_API_KEY",
|
| 30 |
RerankerProvider.JINA.value: "JINA_API_KEY",
|
| 31 |
-
RerankerProvider.VOYAGE.value: "VOYAGE_API_KEY"
|
| 32 |
}
|
| 33 |
|
| 34 |
provider_defaults = {
|
| 35 |
RerankerProvider.HUGGINGFACE.value: "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
| 36 |
RerankerProvider.COHERE.value: "rerank-english-v3.0",
|
| 37 |
RerankerProvider.NVIDIA.value: "nvidia/nv-rerankqa-mistral-4b-v3",
|
| 38 |
-
RerankerProvider.VOYAGE.value: "rerank-1"
|
| 39 |
}
|
| 40 |
|
| 41 |
model = model or provider_defaults.get(provider)
|
|
|
|
| 28 |
RerankerProvider.COHERE.value: "COHERE_API_KEY",
|
| 29 |
RerankerProvider.NVIDIA.value: "NVIDIA_API_KEY",
|
| 30 |
RerankerProvider.JINA.value: "JINA_API_KEY",
|
| 31 |
+
RerankerProvider.VOYAGE.value: "VOYAGE_API_KEY",
|
| 32 |
}
|
| 33 |
|
| 34 |
provider_defaults = {
|
| 35 |
RerankerProvider.HUGGINGFACE.value: "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
| 36 |
RerankerProvider.COHERE.value: "rerank-english-v3.0",
|
| 37 |
RerankerProvider.NVIDIA.value: "nvidia/nv-rerankqa-mistral-4b-v3",
|
| 38 |
+
RerankerProvider.VOYAGE.value: "rerank-1",
|
| 39 |
}
|
| 40 |
|
| 41 |
model = model or provider_defaults.get(provider)
|
sage/vector_store.py
CHANGED
|
@@ -8,8 +8,8 @@ from typing import Dict, Generator, List, Optional, Tuple
|
|
| 8 |
|
| 9 |
import marqo
|
| 10 |
import nltk
|
| 11 |
-
from langchain_community.retrievers import BM25Retriever
|
| 12 |
from langchain.retrievers import EnsembleRetriever
|
|
|
|
| 13 |
from langchain_community.vectorstores import Marqo
|
| 14 |
from langchain_community.vectorstores import Pinecone as LangChainPinecone
|
| 15 |
from langchain_core.documents import Document
|
|
@@ -134,25 +134,28 @@ class PineconeVectorStore(VectorStore):
|
|
| 134 |
self.index.upsert(vectors=pinecone_vectors, namespace=namespace)
|
| 135 |
|
| 136 |
def as_retriever(self, top_k: int, embeddings: Embeddings, namespace: str):
|
| 137 |
-
bm25_retriever =
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
dense_retriever = LangChainPinecone.from_existing_index(
|
| 146 |
index_name=self.index_name, embedding=embeddings, namespace=namespace
|
| 147 |
).as_retriever(search_kwargs={"k": top_k})
|
| 148 |
-
|
| 149 |
if bm25_retriever:
|
| 150 |
-
return EnsembleRetriever(retrievers=[dense_retriever, bm25_retriever], weights=[self.alpha, 1-self.alpha])
|
| 151 |
else:
|
| 152 |
return dense_retriever
|
| 153 |
|
| 154 |
|
| 155 |
-
|
| 156 |
class MarqoVectorStore(VectorStore):
|
| 157 |
"""Vector store implementation using Marqo."""
|
| 158 |
|
|
|
|
| 8 |
|
| 9 |
import marqo
|
| 10 |
import nltk
|
|
|
|
| 11 |
from langchain.retrievers import EnsembleRetriever
|
| 12 |
+
from langchain_community.retrievers import BM25Retriever
|
| 13 |
from langchain_community.vectorstores import Marqo
|
| 14 |
from langchain_community.vectorstores import Pinecone as LangChainPinecone
|
| 15 |
from langchain_core.documents import Document
|
|
|
|
| 134 |
self.index.upsert(vectors=pinecone_vectors, namespace=namespace)
|
| 135 |
|
| 136 |
def as_retriever(self, top_k: int, embeddings: Embeddings, namespace: str):
|
| 137 |
+
bm25_retriever = (
|
| 138 |
+
BM25Retriever(
|
| 139 |
+
embeddings=embeddings,
|
| 140 |
+
sparse_encoder=self.bm25_encoder,
|
| 141 |
+
index=self.index,
|
| 142 |
+
namespace=namespace,
|
| 143 |
+
top_k=top_k,
|
| 144 |
+
)
|
| 145 |
+
if self.bm25_encoder
|
| 146 |
+
else None
|
| 147 |
+
)
|
| 148 |
|
| 149 |
dense_retriever = LangChainPinecone.from_existing_index(
|
| 150 |
index_name=self.index_name, embedding=embeddings, namespace=namespace
|
| 151 |
).as_retriever(search_kwargs={"k": top_k})
|
| 152 |
+
|
| 153 |
if bm25_retriever:
|
| 154 |
+
return EnsembleRetriever(retrievers=[dense_retriever, bm25_retriever], weights=[self.alpha, 1 - self.alpha])
|
| 155 |
else:
|
| 156 |
return dense_retriever
|
| 157 |
|
| 158 |
|
|
|
|
| 159 |
class MarqoVectorStore(VectorStore):
|
| 160 |
"""Vector store implementation using Marqo."""
|
| 161 |
|