juliaturc committed on
Commit
3d780b8
·
1 Parent(s): 52c1352

Add Cohere reranker

Browse files
Files changed (2) hide show
  1. requirements.txt +7 -2
  2. sage/chat.py +9 -1
requirements.txt CHANGED
@@ -1,12 +1,17 @@
1
  GitPython==3.1.43
2
  Pygments==2.18.0
 
3
  fastapi==0.112.2
4
  gradio>=4.26.0
5
- langchain==0.2.14
6
- langchain-community==0.2.12
7
  langchain-anthropic==0.1.23
 
 
 
 
8
  langchain-ollama==0.1.2
9
  langchain-openai==0.1.22
 
10
  marqo==3.7.0
11
  nbformat==5.10.4
12
  openai==1.42.0
 
1
  GitPython==3.1.43
2
  Pygments==2.18.0
3
+ cohere==5.9.2
4
  fastapi==0.112.2
5
  gradio>=4.26.0
6
+ langchain==0.2.15
 
7
  langchain-anthropic==0.1.23
8
+ langchain-cohere==0.2.4
9
+ langchain-community==0.2.12
10
+ langchain-core==0.2.36
11
+ langchain-experimental==0.0.64
12
  langchain-ollama==0.1.2
13
  langchain-openai==0.1.22
14
+ langchain-text-splitters==0.2.2
15
  marqo==3.7.0
16
  nbformat==5.10.4
17
  openai==1.42.0
sage/chat.py CHANGED
@@ -1,9 +1,10 @@
1
  """A gradio app that enables users to chat with their codebase.
2
 
3
- You must run main.py first in order to index the codebase into a vector store.
4
  """
5
 
6
  import argparse
 
7
 
8
  import gradio as gr
9
  from dotenv import load_dotenv
@@ -11,6 +12,8 @@ from langchain.chains import create_history_aware_retriever, create_retrieval_ch
11
  from langchain.chains.combine_documents import create_stuff_documents_chain
12
  from langchain.schema import AIMessage, HumanMessage
13
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 
 
14
 
15
  import sage.vector_store as vector_store
16
  from sage.llm import build_llm_via_langchain
@@ -21,7 +24,11 @@ load_dotenv()
21
  def build_rag_chain(args):
22
  """Builds a RAG chain via LangChain."""
23
  llm = build_llm_via_langchain(args.llm_provider, args.llm_model)
 
24
  retriever = vector_store.build_from_args(args).to_langchain().as_retriever()
 
 
 
25
 
26
  # Prompt to contextualize the latest query based on the chat history.
27
  contextualize_q_system_prompt = (
@@ -82,6 +89,7 @@ def main():
82
  default="http://localhost:8882",
83
  help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
84
  )
 
85
  parser.add_argument(
86
  "--share",
87
  default=False,
 
1
  """A gradio app that enables users to chat with their codebase.
2
 
3
+ You must run `sage-index $GITHUB_REPO` first in order to index the codebase into a vector store.
4
  """
5
 
6
  import argparse
7
+ import os
8
 
9
  import gradio as gr
10
  from dotenv import load_dotenv
 
12
  from langchain.chains.combine_documents import create_stuff_documents_chain
13
  from langchain.schema import AIMessage, HumanMessage
14
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
15
+ from langchain.retrievers import ContextualCompressionRetriever
16
+ from langchain_cohere import CohereRerank
17
 
18
  import sage.vector_store as vector_store
19
  from sage.llm import build_llm_via_langchain
 
24
  def build_rag_chain(args):
25
  """Builds a RAG chain via LangChain."""
26
  llm = build_llm_via_langchain(args.llm_provider, args.llm_model)
27
+
28
  retriever = vector_store.build_from_args(args).to_langchain().as_retriever()
29
+ if args.reranker == "cohere":
30
+ compressor = CohereRerank(model="rerank-english-v3.0", cohere_api_key=os.environ.get("COHERE_API_KEY"))
31
+ retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
32
 
33
  # Prompt to contextualize the latest query based on the chat history.
34
  contextualize_q_system_prompt = (
 
89
  default="http://localhost:8882",
90
  help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
91
  )
92
+ parser.add_argument("--reranker", default="cohere", choices=["none", "cohere"])
93
  parser.add_argument(
94
  "--share",
95
  default=False,