juliaturc committed on
Commit
c010627
·
1 Parent(s): 3d780b8

Add Hugging Face cross-encoders

Browse files
Files changed (2) hide show
  1. requirements.txt +12 -9
  2. sage/chat.py +26 -5
requirements.txt CHANGED
@@ -3,15 +3,15 @@ Pygments==2.18.0
3
  cohere==5.9.2
4
  fastapi==0.112.2
5
  gradio>=4.26.0
6
- langchain==0.2.15
7
- langchain-anthropic==0.1.23
8
- langchain-cohere==0.2.4
9
- langchain-community==0.2.12
10
- langchain-core==0.2.36
11
- langchain-experimental==0.0.64
12
- langchain-ollama==0.1.2
13
- langchain-openai==0.1.22
14
- langchain-text-splitters==0.2.2
15
  marqo==3.7.0
16
  nbformat==5.10.4
17
  openai==1.42.0
@@ -19,6 +19,9 @@ pinecone==5.0.1
19
  python-dotenv==1.0.1
20
  requests==2.32.3
21
  semchunk==2.2.0
 
22
  tiktoken==0.7.0
 
 
23
  tree-sitter==0.22.3
24
  tree-sitter-language-pack==0.2.0
 
3
  cohere==5.9.2
4
  fastapi==0.112.2
5
  gradio>=4.26.0
6
+ langchain==0.3.0
7
+ langchain-anthropic==0.2.0
8
+ langchain-cohere==0.3.0
9
+ langchain-community==0.3.0
10
+ langchain-core==0.3.0
11
+ langchain-experimental==0.3.0
12
+ langchain-ollama==0.2.0
13
+ langchain-openai==0.2.0
14
+ langchain-text-splitters==0.3.0
15
  marqo==3.7.0
16
  nbformat==5.10.4
17
  openai==1.42.0
 
19
  python-dotenv==1.0.1
20
  requests==2.32.3
21
  semchunk==2.2.0
22
+ sentence-transformers==3.1.0
23
  tiktoken==0.7.0
24
+ tokenizers==0.19.1
25
+ transformers==4.44.2
26
  tree-sitter==0.22.3
27
  tree-sitter-language-pack==0.2.0
sage/chat.py CHANGED
@@ -10,10 +10,12 @@ import gradio as gr
10
  from dotenv import load_dotenv
11
  from langchain.chains import create_history_aware_retriever, create_retrieval_chain
12
  from langchain.chains.combine_documents import create_stuff_documents_chain
13
- from langchain.schema import AIMessage, HumanMessage
14
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
15
  from langchain.retrievers import ContextualCompressionRetriever
 
 
16
  from langchain_cohere import CohereRerank
 
 
17
 
18
  import sage.vector_store as vector_store
19
  from sage.llm import build_llm_via_langchain
@@ -26,8 +28,16 @@ def build_rag_chain(args):
26
  llm = build_llm_via_langchain(args.llm_provider, args.llm_model)
27
 
28
  retriever = vector_store.build_from_args(args).to_langchain().as_retriever()
29
- if args.reranker == "cohere":
30
- compressor = CohereRerank(model="rerank-english-v3.0", cohere_api_key=os.environ.get("COHERE_API_KEY"))
 
 
 
 
 
 
 
 
31
  retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
32
 
33
  # Prompt to contextualize the latest query based on the chat history.
@@ -89,7 +99,12 @@ def main():
89
  default="http://localhost:8882",
90
  help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
91
  )
92
- parser.add_argument("--reranker", default="cohere", choices=["none", "cohere"])
 
 
 
 
 
93
  parser.add_argument(
94
  "--share",
95
  default=False,
@@ -113,6 +128,12 @@ def main():
113
  else:
114
  raise ValueError("Please specify --llm_model")
115
 
 
 
 
 
 
 
116
  rag_chain = build_rag_chain(args)
117
 
118
  def _predict(message, history):
 
10
  from dotenv import load_dotenv
11
  from langchain.chains import create_history_aware_retriever, create_retrieval_chain
12
  from langchain.chains.combine_documents import create_stuff_documents_chain
 
 
13
  from langchain.retrievers import ContextualCompressionRetriever
14
+ from langchain.retrievers.document_compressors import CrossEncoderReranker
15
+ from langchain.schema import AIMessage, HumanMessage
16
  from langchain_cohere import CohereRerank
17
+ from langchain_community.cross_encoders import HuggingFaceCrossEncoder
18
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
19
 
20
  import sage.vector_store as vector_store
21
  from sage.llm import build_llm_via_langchain
 
28
  llm = build_llm_via_langchain(args.llm_provider, args.llm_model)
29
 
30
  retriever = vector_store.build_from_args(args).to_langchain().as_retriever()
31
+
32
+ if args.reranker_provider == "none":
33
+ compressor = None
34
+ if args.reranker_provider == "huggingface":
35
+ encoder_model = HuggingFaceCrossEncoder(model_name=args.reranker_model)
36
+ compressor = CrossEncoderReranker(model=encoder_model, top_n=5)
37
+ if args.reranker_provider == "cohere":
38
+ compressor = CohereRerank(model=args.reranker_model, cohere_api_key=os.environ.get("COHERE_API_KEY"), top_n=5)
39
+
40
+ if compressor:
41
  retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
42
 
43
  # Prompt to contextualize the latest query based on the chat history.
 
99
  default="http://localhost:8882",
100
  help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
101
  )
102
+ parser.add_argument("--reranker-provider", default="huggingface", choices=["none", "huggingface", "cohere"])
103
+ parser.add_argument(
104
+ "--reranker-model",
105
+ help="The reranker model name. When --reranker-provider=huggingface, we suggest choosing a model from the "
106
+ "SentenceTransformers Cross-Encoders library https://huggingface.co/cross-encoder?sort_models=downloads#models",
107
+ )
108
  parser.add_argument(
109
  "--share",
110
  default=False,
 
128
  else:
129
  raise ValueError("Please specify --llm_model")
130
 
131
+ if not args.reranker_model:
132
+ if args.reranker_provider == "cohere":
133
+ args.reranker_model = "rerank-english-v3.0"
134
+ elif args.reranker_provider == "huggingface":
135
+ args.reranker_model = "cross-encoder/ms-marco-TinyBERT-L-2-v2"
136
+
137
  rag_chain = build_rag_chain(args)
138
 
139
  def _predict(message, history):