Spaces:
Running
Running
Add Cohere, NVIDIA, Jina & HuggingFace rerankers (#37)
Browse files- README.md +13 -1
- requirements.txt +10 -9
- sage/chat.py +3 -25
- sage/reranker.py +42 -0
- sage/vector_store.py +5 -1
README.md
CHANGED
|
@@ -94,10 +94,22 @@ pip install git+https://github.com/Storia-AI/sage.git@main
|
|
| 94 |
|
| 95 |
<br>
|
| 96 |
<summary><strong>Optional</strong></summary>
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
export GITHUB_TOKEN=...
|
| 100 |
|
|
|
|
| 101 |
## Running it
|
| 102 |
|
| 103 |
<details open>
|
|
|
|
| 94 |
|
| 95 |
<br>
|
| 96 |
<summary><strong>Optional</strong></summary>
|
| 97 |
+
|
| 98 |
+
- By default, we use an <a href="https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2">open-source re-ranker</a>. For higher accuracy, you can use <a href="https://cohere.com/rerank">Cohere</a>, <a href="https://developer.nvidia.com/blog/enhancing-rag-pipelines-with-re-ranking/">NVIDIA</a> or <a href="https://jina.ai/reranker/">Jina</a>:
|
| 99 |
+
|
| 100 |
+
```
|
| 101 |
+
export COHERE_API_KEY=...
|
| 102 |
+
export NVIDIA_API_KEY=...
|
| 103 |
+
export JINA_API_KEY=...
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
We are seeing significant gains in accuracy from these proprietary rerankers.
|
| 107 |
+
|
| 108 |
+
- If you are planning on indexing GitHub issues in addition to the codebase, you will need a GitHub token:
|
| 109 |
|
| 110 |
export GITHUB_TOKEN=...
|
| 111 |
|
| 112 |
+
|
| 113 |
## Running it
|
| 114 |
|
| 115 |
<details open>
|
requirements.txt
CHANGED
|
@@ -3,15 +3,16 @@ Pygments==2.18.0
|
|
| 3 |
cohere==5.9.2
|
| 4 |
fastapi==0.112.2
|
| 5 |
gradio>=4.26.0
|
| 6 |
-
langchain==0.
|
| 7 |
-
langchain-anthropic==0.
|
| 8 |
-
langchain-cohere==0.
|
| 9 |
-
langchain-community==0.
|
| 10 |
-
langchain-core==0.
|
| 11 |
-
langchain-experimental==0.
|
| 12 |
-
langchain-
|
| 13 |
-
langchain-
|
| 14 |
-
langchain-
|
|
|
|
| 15 |
marqo==3.7.0
|
| 16 |
nbformat==5.10.4
|
| 17 |
openai==1.42.0
|
|
|
|
| 3 |
cohere==5.9.2
|
| 4 |
fastapi==0.112.2
|
| 5 |
gradio>=4.26.0
|
| 6 |
+
langchain==0.2.16
|
| 7 |
+
langchain-anthropic==0.1.23
|
| 8 |
+
langchain-cohere==0.2.4
|
| 9 |
+
langchain-community==0.2.17
|
| 10 |
+
langchain-core==0.2.40
|
| 11 |
+
langchain-experimental==0.0.65
|
| 12 |
+
langchain-nvidia-ai-endpoints==0.2.2
|
| 13 |
+
langchain-ollama==0.1.3
|
| 14 |
+
langchain-openai==0.1.25
|
| 15 |
+
langchain-text-splitters==0.2.4
|
| 16 |
marqo==3.7.0
|
| 17 |
nbformat==5.10.4
|
| 18 |
openai==1.42.0
|
sage/chat.py
CHANGED
|
@@ -12,14 +12,12 @@ from dotenv import load_dotenv
|
|
| 12 |
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
|
| 13 |
from langchain.chains.combine_documents import create_stuff_documents_chain
|
| 14 |
from langchain.retrievers import ContextualCompressionRetriever
|
| 15 |
-
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
| 16 |
from langchain.schema import AIMessage, HumanMessage
|
| 17 |
-
from langchain_cohere import CohereRerank
|
| 18 |
-
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
| 19 |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 20 |
|
| 21 |
import sage.vector_store as vector_store
|
| 22 |
from sage.llm import build_llm_via_langchain
|
|
|
|
| 23 |
|
| 24 |
load_dotenv()
|
| 25 |
|
|
@@ -30,15 +28,7 @@ def build_rag_chain(args):
|
|
| 30 |
|
| 31 |
retriever_top_k = 5 if args.reranker_provider == "none" else 25
|
| 32 |
retriever = vector_store.build_from_args(args).as_retriever(top_k=retriever_top_k)
|
| 33 |
-
|
| 34 |
-
if args.reranker_provider == "none":
|
| 35 |
-
compressor = None
|
| 36 |
-
if args.reranker_provider == "huggingface":
|
| 37 |
-
encoder_model = HuggingFaceCrossEncoder(model_name=args.reranker_model)
|
| 38 |
-
compressor = CrossEncoderReranker(model=encoder_model, top_n=5)
|
| 39 |
-
if args.reranker_provider == "cohere":
|
| 40 |
-
compressor = CohereRerank(model=args.reranker_model, cohere_api_key=os.environ.get("COHERE_API_KEY"), top_n=5)
|
| 41 |
-
|
| 42 |
if compressor:
|
| 43 |
retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
|
| 44 |
|
|
@@ -94,7 +84,7 @@ def main():
|
|
| 94 |
default="http://localhost:8882",
|
| 95 |
help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
|
| 96 |
)
|
| 97 |
-
parser.add_argument("--reranker-provider", default="huggingface", choices=[
|
| 98 |
parser.add_argument(
|
| 99 |
"--reranker-model",
|
| 100 |
help="The reranker model name. When --reranker-provider=huggingface, we suggest choosing a model from the "
|
|
@@ -114,12 +104,6 @@ def main():
|
|
| 114 |
)
|
| 115 |
args = parser.parse_args()
|
| 116 |
|
| 117 |
-
if not args.index_name:
|
| 118 |
-
if args.vector_store_type == "marqo":
|
| 119 |
-
args.index_name = args.repo_id.split("/")[1]
|
| 120 |
-
elif args.vector_store_type == "pinecone":
|
| 121 |
-
parser.error("Please specify --index-name for Pinecone.")
|
| 122 |
-
|
| 123 |
if not args.llm_model:
|
| 124 |
if args.llm_provider == "openai":
|
| 125 |
args.llm_model = "gpt-4"
|
|
@@ -130,12 +114,6 @@ def main():
|
|
| 130 |
else:
|
| 131 |
raise ValueError("Please specify --llm_model")
|
| 132 |
|
| 133 |
-
if not args.reranker_model:
|
| 134 |
-
if args.reranker_provider == "cohere":
|
| 135 |
-
args.reranker_model = "rerank-english-v3.0"
|
| 136 |
-
elif args.reranker_provider == "huggingface":
|
| 137 |
-
args.reranker_model = "cross-encoder/ms-marco-TinyBERT-L-2-v2"
|
| 138 |
-
|
| 139 |
rag_chain = build_rag_chain(args)
|
| 140 |
|
| 141 |
def source_md(file_path: str, url: str) -> str:
|
|
|
|
| 12 |
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
|
| 13 |
from langchain.chains.combine_documents import create_stuff_documents_chain
|
| 14 |
from langchain.retrievers import ContextualCompressionRetriever
|
|
|
|
| 15 |
from langchain.schema import AIMessage, HumanMessage
|
|
|
|
|
|
|
| 16 |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 17 |
|
| 18 |
import sage.vector_store as vector_store
|
| 19 |
from sage.llm import build_llm_via_langchain
|
| 20 |
+
from sage.reranker import build_reranker, RerankerProvider
|
| 21 |
|
| 22 |
load_dotenv()
|
| 23 |
|
|
|
|
| 28 |
|
| 29 |
retriever_top_k = 5 if args.reranker_provider == "none" else 25
|
| 30 |
retriever = vector_store.build_from_args(args).as_retriever(top_k=retriever_top_k)
|
| 31 |
+
compressor = build_reranker(args.reranker_provider, args.reranker_model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
if compressor:
|
| 33 |
retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
|
| 34 |
|
|
|
|
| 84 |
default="http://localhost:8882",
|
| 85 |
help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
|
| 86 |
)
|
| 87 |
+
parser.add_argument("--reranker-provider", default="huggingface", choices=[r.value for r in RerankerProvider])
|
| 88 |
parser.add_argument(
|
| 89 |
"--reranker-model",
|
| 90 |
help="The reranker model name. When --reranker-provider=huggingface, we suggest choosing a model from the "
|
|
|
|
| 104 |
)
|
| 105 |
args = parser.parse_args()
|
| 106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
if not args.llm_model:
|
| 108 |
if args.llm_provider == "openai":
|
| 109 |
args.llm_model = "gpt-4"
|
|
|
|
| 114 |
else:
|
| 115 |
raise ValueError("Please specify --llm_model")
|
| 116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
rag_chain = build_rag_chain(args)
|
| 118 |
|
| 119 |
def source_md(file_path: str, url: str) -> str:
|
sage/reranker.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from enum import Enum
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
| 6 |
+
from langchain_cohere import CohereRerank
|
| 7 |
+
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
| 8 |
+
from langchain_community.document_compressors import JinaRerank
|
| 9 |
+
from langchain_core.documents import BaseDocumentCompressor
|
| 10 |
+
from langchain_nvidia_ai_endpoints import NVIDIARerank
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class RerankerProvider(Enum):
    """Closed set of providers we can build a reranker for.

    The string values are what users pass on the command line via
    --reranker-provider; "none" disables reranking entirely.
    """

    NONE = "none"
    HUGGINGFACE = "huggingface"
    COHERE = "cohere"
    NVIDIA = "nvidia"
    JINA = "jina"
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def build_reranker(provider: str, model: Optional[str] = None, top_n: int = 5) -> Optional[BaseDocumentCompressor]:
    """Builds a document compressor (reranker) for the given provider.

    Args:
        provider: One of the RerankerProvider values ("none", "huggingface",
            "cohere", "nvidia" or "jina").
        model: Provider-specific model name; a sensible per-provider default
            is used when omitted.
        top_n: Number of documents to keep after reranking.

    Returns:
        A LangChain document compressor, or None when provider is "none".

    Raises:
        ValueError: If the provider is unrecognized, or the API key
            environment variable required by the provider is not set.
    """
    if provider == RerankerProvider.NONE.value:
        return None

    if provider == RerankerProvider.HUGGINGFACE.value:
        encoder_model = HuggingFaceCrossEncoder(model_name=model or "cross-encoder/ms-marco-MiniLM-L-6-v2")
        return CrossEncoderReranker(model=encoder_model, top_n=top_n)

    if provider == RerankerProvider.COHERE.value:
        cohere_key = os.environ.get("COHERE_API_KEY")
        if not cohere_key:
            raise ValueError("Please set the COHERE_API_KEY environment variable")
        return CohereRerank(model=model or "rerank-english-v3.0", cohere_api_key=cohere_key, top_n=top_n)

    if provider == RerankerProvider.NVIDIA.value:
        nvidia_key = os.environ.get("NVIDIA_API_KEY")
        if not nvidia_key:
            raise ValueError("Please set the NVIDIA_API_KEY environment variable")
        # truncate="END" tells the NVIDIA endpoint to clip over-long passages
        # instead of erroring out on them.
        return NVIDIARerank(
            model=model or "nvidia/nv-rerankqa-mistral-4b-v3", api_key=nvidia_key, top_n=top_n, truncate="END"
        )

    if provider == RerankerProvider.JINA.value:
        # JinaRerank reads JINA_API_KEY from the environment itself; we check
        # it up front only to fail fast with a clearer message.
        if not os.environ.get("JINA_API_KEY"):
            raise ValueError("Please set the JINA_API_KEY environment variable")
        return JinaRerank(top_n=top_n)

    raise ValueError(f"Invalid reranker provider: {provider}")
|
sage/vector_store.py
CHANGED
|
@@ -149,11 +149,15 @@ class MarqoVectorStore(VectorStore):
|
|
| 149 |
def build_from_args(args: dict) -> VectorStore:
|
| 150 |
"""Builds a vector store from the given command-line arguments."""
|
| 151 |
if args.vector_store_type == "pinecone":
|
|
|
|
|
|
|
| 152 |
dimension = args.embedding_size if "embedding_size" in args else None
|
| 153 |
return PineconeVectorStore(
|
| 154 |
index_name=args.index_name, namespace=args.repo_id, dimension=dimension, hybrid=args.hybrid_retrieval
|
| 155 |
)
|
| 156 |
elif args.vector_store_type == "marqo":
|
| 157 |
-
|
|
|
|
|
|
|
| 158 |
else:
|
| 159 |
raise ValueError(f"Unrecognized vector store type {args.vector_store_type}")
|
|
|
|
| 149 |
def build_from_args(args: dict) -> VectorStore:
    """Builds a vector store from the given command-line arguments."""
    store_type = args.vector_store_type

    if store_type == "pinecone":
        # Pinecone has no sensible index-name default, so require one.
        if not args.index_name:
            raise ValueError("Please specify --index-name for Pinecone.")
        # embedding_size may be absent from args; fall back to None.
        dim = args.embedding_size if "embedding_size" in args else None
        return PineconeVectorStore(
            index_name=args.index_name,
            namespace=args.repo_id,
            dimension=dim,
            hybrid=args.hybrid_retrieval,
        )

    if store_type == "marqo":
        # Default to a local Marqo server; derive the index name from the
        # repository name ("org/repo" -> "repo") when not given explicitly.
        url = args.marqo_url or "http://localhost:8882"
        index = args.index_name or args.repo_id.split("/")[1]
        return MarqoVectorStore(url=url, index_name=index)

    raise ValueError(f"Unrecognized vector store type {store_type}")
|