juliaturc committed
Commit 7a1bb92 · 1 Parent(s): 4d7acde

Add Cohere, NVIDIA, Jina & HuggingFace rerankers (#37)

Files changed (5):
  1. README.md +13 -1
  2. requirements.txt +10 -9
  3. sage/chat.py +3 -25
  4. sage/reranker.py +42 -0
  5. sage/vector_store.py +5 -1
README.md CHANGED
@@ -94,10 +94,22 @@ pip install git+https://github.com/Storia-AI/sage.git@main
 
 <br>
 <summary><strong>Optional</strong></summary>
-If you are planning on indexing GitHub issues in addition to the codebase, you will need a GitHub token:
+
+- By default, we use an <a href="https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2">open-source re-ranker</a>. For higher accuracy, you can use <a href="https://cohere.com/rerank">Cohere</a>, <a href="https://developer.nvidia.com/blog/enhancing-rag-pipelines-with-re-ranking/">NVIDIA</a> or <a href="https://jina.ai/reranker/">Jina</a>:
+
+```
+export COHERE_API_KEY=...
+export NVIDIA_API_KEY=...
+export JINA_API_KEY=...
+```
+
+We are seeing significant gains in accuracy from these proprietary rerankers.
+
+- If you are planning on indexing GitHub issues in addition to the codebase, you will need a GitHub token:
 
 export GITHUB_TOKEN=...
 
+
 ## Running it
 
 <details open>
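These keys are consumed by the new `sage/reranker.py` module (full diff below): each proprietary provider fails fast with a `ValueError` if its variable is unset, while the HuggingFace default needs no key. A minimal pre-flight check along those lines, using a hypothetical provider-to-variable mapping that mirrors the module's checks:

```python
import os

# Hypothetical pre-flight check (not part of the repo): report which reranker
# providers are usable in the current shell, mirroring sage/reranker.py.
PROVIDER_ENV_VARS = {"cohere": "COHERE_API_KEY", "nvidia": "NVIDIA_API_KEY", "jina": "JINA_API_KEY"}

for provider, var in PROVIDER_ENV_VARS.items():
    status = "ready" if os.environ.get(var) else f"missing {var}"
    print(f"{provider}: {status}")
print("huggingface: ready (open-source, no API key needed)")
```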
requirements.txt CHANGED
@@ -3,15 +3,16 @@ Pygments==2.18.0
 cohere==5.9.2
 fastapi==0.112.2
 gradio>=4.26.0
-langchain==0.3.0
-langchain-anthropic==0.2.0
-langchain-cohere==0.3.0
-langchain-community==0.3.0
-langchain-core==0.3.0
-langchain-experimental==0.3.0
-langchain-ollama==0.2.0
-langchain-openai==0.2.0
-langchain-text-splitters==0.3.0
+langchain==0.2.16
+langchain-anthropic==0.1.23
+langchain-cohere==0.2.4
+langchain-community==0.2.17
+langchain-core==0.2.40
+langchain-experimental==0.0.65
+langchain-nvidia-ai-endpoints==0.2.2
+langchain-ollama==0.1.3
+langchain-openai==0.1.25
+langchain-text-splitters==0.2.4
 marqo==3.7.0
 nbformat==5.10.4
 openai==1.42.0
sage/chat.py CHANGED
@@ -12,14 +12,12 @@ from dotenv import load_dotenv
 from langchain.chains import create_history_aware_retriever, create_retrieval_chain
 from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain.retrievers import ContextualCompressionRetriever
-from langchain.retrievers.document_compressors import CrossEncoderReranker
 from langchain.schema import AIMessage, HumanMessage
-from langchain_cohere import CohereRerank
-from langchain_community.cross_encoders import HuggingFaceCrossEncoder
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 
 import sage.vector_store as vector_store
 from sage.llm import build_llm_via_langchain
+from sage.reranker import build_reranker, RerankerProvider
 
 load_dotenv()
 
@@ -30,15 +28,7 @@ def build_rag_chain(args):
 
     retriever_top_k = 5 if args.reranker_provider == "none" else 25
     retriever = vector_store.build_from_args(args).as_retriever(top_k=retriever_top_k)
-
-    if args.reranker_provider == "none":
-        compressor = None
-    if args.reranker_provider == "huggingface":
-        encoder_model = HuggingFaceCrossEncoder(model_name=args.reranker_model)
-        compressor = CrossEncoderReranker(model=encoder_model, top_n=5)
-    if args.reranker_provider == "cohere":
-        compressor = CohereRerank(model=args.reranker_model, cohere_api_key=os.environ.get("COHERE_API_KEY"), top_n=5)
-
+    compressor = build_reranker(args.reranker_provider, args.reranker_model)
     if compressor:
         retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
 
@@ -94,7 +84,7 @@ def main():
         default="http://localhost:8882",
         help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
     )
-    parser.add_argument("--reranker-provider", default="huggingface", choices=["none", "huggingface", "cohere"])
+    parser.add_argument("--reranker-provider", default="huggingface", choices=[r.value for r in RerankerProvider])
     parser.add_argument(
         "--reranker-model",
         help="The reranker model name. When --reranker-provider=huggingface, we suggest choosing a model from the "
@@ -114,12 +104,6 @@
     )
     args = parser.parse_args()
 
-    if not args.index_name:
-        if args.vector_store_type == "marqo":
-            args.index_name = args.repo_id.split("/")[1]
-        elif args.vector_store_type == "pinecone":
-            parser.error("Please specify --index-name for Pinecone.")
-
     if not args.llm_model:
         if args.llm_provider == "openai":
             args.llm_model = "gpt-4"
@@ -130,12 +114,6 @@
     else:
         raise ValueError("Please specify --llm_model")
 
-    if not args.reranker_model:
-        if args.reranker_provider == "cohere":
-            args.reranker_model = "rerank-english-v3.0"
-        elif args.reranker_provider == "huggingface":
-            args.reranker_model = "cross-encoder/ms-marco-TinyBERT-L-2-v2"
-
     rag_chain = build_rag_chain(args)
 
     def source_md(file_path: str, url: str) -> str:
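The net effect of this refactor: provider-specific construction moves out of `build_rag_chain` into `build_reranker`, leaving the chain to decide only whether to wrap the base retriever. A minimal sketch of that pattern (the `wrap_with_reranker` helper is hypothetical; `base_retriever` stands in for whatever the vector store returns):

```python
from typing import Optional

from langchain.retrievers import ContextualCompressionRetriever

from sage.reranker import build_reranker


def wrap_with_reranker(base_retriever, provider: str = "huggingface", model: Optional[str] = None):
    """Hypothetical helper mirroring build_rag_chain: over-fetch candidates
    from the base retriever, then let the reranker keep the top 5."""
    compressor = build_reranker(provider, model)
    if compressor is None:  # provider == "none": use the raw retriever
        return base_retriever
    return ContextualCompressionRetriever(base_compressor=compressor, base_retriever=base_retriever)
```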
sage/reranker.py ADDED
@@ -0,0 +1,42 @@
+import os
+from enum import Enum
+from typing import Optional
+
+from langchain.retrievers.document_compressors import CrossEncoderReranker
+from langchain_cohere import CohereRerank
+from langchain_community.cross_encoders import HuggingFaceCrossEncoder
+from langchain_community.document_compressors import JinaRerank
+from langchain_core.documents import BaseDocumentCompressor
+from langchain_nvidia_ai_endpoints import NVIDIARerank
+
+
+class RerankerProvider(Enum):
+    NONE = "none"
+    HUGGINGFACE = "huggingface"
+    COHERE = "cohere"
+    NVIDIA = "nvidia"
+    JINA = "jina"
+
+
+def build_reranker(provider: str, model: Optional[str] = None, top_n: Optional[int] = 5) -> BaseDocumentCompressor:
+    if provider == RerankerProvider.NONE.value:
+        return None
+    if provider == RerankerProvider.HUGGINGFACE.value:
+        model = model or "cross-encoder/ms-marco-MiniLM-L-6-v2"
+        encoder_model = HuggingFaceCrossEncoder(model_name=model)
+        return CrossEncoderReranker(model=encoder_model, top_n=top_n)
+    if provider == RerankerProvider.COHERE.value:
+        if not os.environ.get("COHERE_API_KEY"):
+            raise ValueError("Please set the COHERE_API_KEY environment variable")
+        model = model or "rerank-english-v3.0"
+        return CohereRerank(model=model, cohere_api_key=os.environ.get("COHERE_API_KEY"), top_n=top_n)
+    if provider == RerankerProvider.NVIDIA.value:
+        if not os.environ.get("NVIDIA_API_KEY"):
+            raise ValueError("Please set the NVIDIA_API_KEY environment variable")
+        model = model or "nvidia/nv-rerankqa-mistral-4b-v3"
+        return NVIDIARerank(model=model, api_key=os.environ.get("NVIDIA_API_KEY"), top_n=top_n, truncate="END")
+    if provider == RerankerProvider.JINA.value:
+        if not os.environ.get("JINA_API_KEY"):
+            raise ValueError("Please set the JINA_API_KEY environment variable")
+        return JinaRerank(top_n=top_n)
+    raise ValueError(f"Invalid reranker provider: {provider}")
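A usage sketch for the new module (documents and query are illustrative; `compress_documents` is LangChain's `BaseDocumentCompressor` entry point):

```python
from langchain_core.documents import Document

from sage.reranker import build_reranker, RerankerProvider

# Open-source default: downloads cross-encoder/ms-marco-MiniLM-L-6-v2; no API key.
compressor = build_reranker(RerankerProvider.HUGGINGFACE.value, top_n=2)

# Illustrative candidates, as if over-fetched from the vector store.
docs = [
    Document(page_content="def build_reranker(provider, model=None, top_n=5): ..."),
    Document(page_content="README: installation and setup steps"),
    Document(page_content="CHANGELOG for the previous release"),
]
top = compressor.compress_documents(docs, query="how do I build a reranker?")
```

Swapping in `RerankerProvider.COHERE.value` (or `NVIDIA`, `JINA`) is the same call, but fails fast with a `ValueError` if the corresponding API key is missing.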
sage/vector_store.py CHANGED
@@ -149,11 +149,15 @@ class MarqoVectorStore(VectorStore):
 def build_from_args(args: dict) -> VectorStore:
     """Builds a vector store from the given command-line arguments."""
     if args.vector_store_type == "pinecone":
+        if not args.index_name:
+            raise ValueError("Please specify --index-name for Pinecone.")
         dimension = args.embedding_size if "embedding_size" in args else None
         return PineconeVectorStore(
             index_name=args.index_name, namespace=args.repo_id, dimension=dimension, hybrid=args.hybrid_retrieval
         )
     elif args.vector_store_type == "marqo":
-        return MarqoVectorStore(url=args.marqo_url, index_name=args.index_name)
+        marqo_url = args.marqo_url or "http://localhost:8882"
+        index_name = args.index_name or args.repo_id.split("/")[1]
+        return MarqoVectorStore(url=marqo_url, index_name=index_name)
     else:
         raise ValueError(f"Unrecognized vector store type {args.vector_store_type}")
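With these fallbacks, a Marqo store no longer requires an explicit index name or URL. An illustrative call (argument values are made up; any `argparse.Namespace`-like object works):

```python
from argparse import Namespace

import sage.vector_store as vector_store

# Illustrative arguments: index_name falls back to the repo name ("sage")
# and marqo_url to the local default, per the fallbacks above.
args = Namespace(
    vector_store_type="marqo",
    repo_id="Storia-AI/sage",
    marqo_url=None,    # -> "http://localhost:8882"
    index_name=None,   # -> "sage"
)
store = vector_store.build_from_args(args)
```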