Spaces:
Running
Add Hugging Face cross-encoders
Browse files
- requirements.txt +12 -9
- sage/chat.py +26 -5
requirements.txt
CHANGED
|
@@ -3,15 +3,15 @@ Pygments==2.18.0
|
|
| 3 |
cohere==5.9.2
|
| 4 |
fastapi==0.112.2
|
| 5 |
gradio>=4.26.0
|
| 6 |
-
langchain==0.
|
| 7 |
-
langchain-anthropic==0.
|
| 8 |
-
langchain-cohere==0.
|
| 9 |
-
langchain-community==0.
|
| 10 |
-
langchain-core==0.
|
| 11 |
-
langchain-experimental==0.0
|
| 12 |
-
langchain-ollama==0.
|
| 13 |
-
langchain-openai==0.
|
| 14 |
-
langchain-text-splitters==0.
|
| 15 |
marqo==3.7.0
|
| 16 |
nbformat==5.10.4
|
| 17 |
openai==1.42.0
|
|
@@ -19,6 +19,9 @@ pinecone==5.0.1
|
|
| 19 |
python-dotenv==1.0.1
|
| 20 |
requests==2.32.3
|
| 21 |
semchunk==2.2.0
|
|
|
|
| 22 |
tiktoken==0.7.0
|
|
|
|
|
|
|
| 23 |
tree-sitter==0.22.3
|
| 24 |
tree-sitter-language-pack==0.2.0
|
|
|
|
| 3 |
cohere==5.9.2
|
| 4 |
fastapi==0.112.2
|
| 5 |
gradio>=4.26.0
|
| 6 |
+
langchain==0.3.0
|
| 7 |
+
langchain-anthropic==0.2.0
|
| 8 |
+
langchain-cohere==0.3.0
|
| 9 |
+
langchain-community==0.3.0
|
| 10 |
+
langchain-core==0.3.0
|
| 11 |
+
langchain-experimental==0.3.0
|
| 12 |
+
langchain-ollama==0.2.0
|
| 13 |
+
langchain-openai==0.2.0
|
| 14 |
+
langchain-text-splitters==0.3.0
|
| 15 |
marqo==3.7.0
|
| 16 |
nbformat==5.10.4
|
| 17 |
openai==1.42.0
|
|
|
|
| 19 |
python-dotenv==1.0.1
|
| 20 |
requests==2.32.3
|
| 21 |
semchunk==2.2.0
|
| 22 |
+
sentence-transformers==3.1.0
|
| 23 |
tiktoken==0.7.0
|
| 24 |
+
tokenizers==0.19.1
|
| 25 |
+
transformers==4.44.2
|
| 26 |
tree-sitter==0.22.3
|
| 27 |
tree-sitter-language-pack==0.2.0
|
sage/chat.py
CHANGED
|
@@ -10,10 +10,12 @@ import gradio as gr
|
|
| 10 |
from dotenv import load_dotenv
|
| 11 |
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
|
| 12 |
from langchain.chains.combine_documents import create_stuff_documents_chain
|
| 13 |
-
from langchain.schema import AIMessage, HumanMessage
|
| 14 |
-
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 15 |
from langchain.retrievers import ContextualCompressionRetriever
|
|
|
|
|
|
|
| 16 |
from langchain_cohere import CohereRerank
|
|
|
|
|
|
|
| 17 |
|
| 18 |
import sage.vector_store as vector_store
|
| 19 |
from sage.llm import build_llm_via_langchain
|
|
@@ -26,8 +28,16 @@ def build_rag_chain(args):
|
|
| 26 |
llm = build_llm_via_langchain(args.llm_provider, args.llm_model)
|
| 27 |
|
| 28 |
retriever = vector_store.build_from_args(args).to_langchain().as_retriever()
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
|
| 32 |
|
| 33 |
# Prompt to contextualize the latest query based on the chat history.
|
|
@@ -89,7 +99,12 @@ def main():
|
|
| 89 |
default="http://localhost:8882",
|
| 90 |
help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
|
| 91 |
)
|
| 92 |
-
parser.add_argument("--reranker", default="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
parser.add_argument(
|
| 94 |
"--share",
|
| 95 |
default=False,
|
|
@@ -113,6 +128,12 @@ def main():
|
|
| 113 |
else:
|
| 114 |
raise ValueError("Please specify --llm_model")
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
rag_chain = build_rag_chain(args)
|
| 117 |
|
| 118 |
def _predict(message, history):
|
|
|
|
| 10 |
from dotenv import load_dotenv
|
| 11 |
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
|
| 12 |
from langchain.chains.combine_documents import create_stuff_documents_chain
|
|
|
|
|
|
|
| 13 |
from langchain.retrievers import ContextualCompressionRetriever
|
| 14 |
+
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
| 15 |
+
from langchain.schema import AIMessage, HumanMessage
|
| 16 |
from langchain_cohere import CohereRerank
|
| 17 |
+
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
| 18 |
+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 19 |
|
| 20 |
import sage.vector_store as vector_store
|
| 21 |
from sage.llm import build_llm_via_langchain
|
|
|
|
| 28 |
llm = build_llm_via_langchain(args.llm_provider, args.llm_model)
|
| 29 |
|
| 30 |
retriever = vector_store.build_from_args(args).to_langchain().as_retriever()
|
| 31 |
+
|
| 32 |
+
if args.reranker_provider == "none":
|
| 33 |
+
compressor = None
|
| 34 |
+
if args.reranker_provider == "huggingface":
|
| 35 |
+
encoder_model = HuggingFaceCrossEncoder(model_name=args.reranker_model)
|
| 36 |
+
compressor = CrossEncoderReranker(model=encoder_model, top_n=5)
|
| 37 |
+
if args.reranker_provider == "cohere":
|
| 38 |
+
compressor = CohereRerank(model=args.reranker_model, cohere_api_key=os.environ.get("COHERE_API_KEY"), top_n=5)
|
| 39 |
+
|
| 40 |
+
if compressor:
|
| 41 |
retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
|
| 42 |
|
| 43 |
# Prompt to contextualize the latest query based on the chat history.
|
|
|
|
| 99 |
default="http://localhost:8882",
|
| 100 |
help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
|
| 101 |
)
|
| 102 |
+
parser.add_argument("--reranker-provider", default="huggingface", choices=["none", "huggingface", "cohere"])
|
| 103 |
+
parser.add_argument(
|
| 104 |
+
"--reranker-model",
|
| 105 |
+
help="The reranker model name. When --reranker-provider=huggingface, we suggest choosing a model from the "
|
| 106 |
+
"SentenceTransformers Cross-Encoders library https://huggingface.co/cross-encoder?sort_models=downloads#models",
|
| 107 |
+
)
|
| 108 |
parser.add_argument(
|
| 109 |
"--share",
|
| 110 |
default=False,
|
|
|
|
| 128 |
else:
|
| 129 |
raise ValueError("Please specify --llm_model")
|
| 130 |
|
| 131 |
+
if not args.reranker_model:
|
| 132 |
+
if args.reranker_provider == "cohere":
|
| 133 |
+
args.reranker_model = "rerank-english-v3.0"
|
| 134 |
+
elif args.reranker_provider == "huggingface":
|
| 135 |
+
args.reranker_model = "cross-encoder/ms-marco-TinyBERT-L-2-v2"
|
| 136 |
+
|
| 137 |
rag_chain = build_rag_chain(args)
|
| 138 |
|
| 139 |
def _predict(message, history):
|