Spaces:
Running
Running
Add Cohere reranker
Browse files- requirements.txt +7 -2
- sage/chat.py +9 -1
requirements.txt
CHANGED
|
@@ -1,12 +1,17 @@
|
|
| 1 |
GitPython==3.1.43
|
| 2 |
Pygments==2.18.0
|
|
|
|
| 3 |
fastapi==0.112.2
|
| 4 |
gradio>=4.26.0
|
| 5 |
-
langchain==0.2.
|
| 6 |
-
langchain-community==0.2.12
|
| 7 |
langchain-anthropic==0.1.23
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
langchain-ollama==0.1.2
|
| 9 |
langchain-openai==0.1.22
|
|
|
|
| 10 |
marqo==3.7.0
|
| 11 |
nbformat==5.10.4
|
| 12 |
openai==1.42.0
|
|
|
|
| 1 |
GitPython==3.1.43
|
| 2 |
Pygments==2.18.0
|
| 3 |
+
cohere==5.9.2
|
| 4 |
fastapi==0.112.2
|
| 5 |
gradio>=4.26.0
|
| 6 |
+
langchain==0.2.15
|
|
|
|
| 7 |
langchain-anthropic==0.1.23
|
| 8 |
+
langchain-cohere==0.2.4
|
| 9 |
+
langchain-community==0.2.12
|
| 10 |
+
langchain-core==0.2.36
|
| 11 |
+
langchain-experimental==0.0.64
|
| 12 |
langchain-ollama==0.1.2
|
| 13 |
langchain-openai==0.1.22
|
| 14 |
+
langchain-text-splitters==0.2.2
|
| 15 |
marqo==3.7.0
|
| 16 |
nbformat==5.10.4
|
| 17 |
openai==1.42.0
|
sage/chat.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
| 1 |
"""A gradio app that enables users to chat with their codebase.
|
| 2 |
|
| 3 |
-
You must run
|
| 4 |
"""
|
| 5 |
|
| 6 |
import argparse
|
|
|
|
| 7 |
|
| 8 |
import gradio as gr
|
| 9 |
from dotenv import load_dotenv
|
|
@@ -11,6 +12,8 @@ from langchain.chains import create_history_aware_retriever, create_retrieval_ch
|
|
| 11 |
from langchain.chains.combine_documents import create_stuff_documents_chain
|
| 12 |
from langchain.schema import AIMessage, HumanMessage
|
| 13 |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|
|
|
|
|
|
| 14 |
|
| 15 |
import sage.vector_store as vector_store
|
| 16 |
from sage.llm import build_llm_via_langchain
|
|
@@ -21,7 +24,11 @@ load_dotenv()
|
|
| 21 |
def build_rag_chain(args):
|
| 22 |
"""Builds a RAG chain via LangChain."""
|
| 23 |
llm = build_llm_via_langchain(args.llm_provider, args.llm_model)
|
|
|
|
| 24 |
retriever = vector_store.build_from_args(args).to_langchain().as_retriever()
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
# Prompt to contextualize the latest query based on the chat history.
|
| 27 |
contextualize_q_system_prompt = (
|
|
@@ -82,6 +89,7 @@ def main():
|
|
| 82 |
default="http://localhost:8882",
|
| 83 |
help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
|
| 84 |
)
|
|
|
|
| 85 |
parser.add_argument(
|
| 86 |
"--share",
|
| 87 |
default=False,
|
|
|
|
| 1 |
"""A gradio app that enables users to chat with their codebase.
|
| 2 |
|
| 3 |
+
You must run `sage-index $GITHUB_REPO` first in order to index the codebase into a vector store.
|
| 4 |
"""
|
| 5 |
|
| 6 |
import argparse
|
| 7 |
+
import os
|
| 8 |
|
| 9 |
import gradio as gr
|
| 10 |
from dotenv import load_dotenv
|
|
|
|
| 12 |
from langchain.chains.combine_documents import create_stuff_documents_chain
|
| 13 |
from langchain.schema import AIMessage, HumanMessage
|
| 14 |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 15 |
+
from langchain.retrievers import ContextualCompressionRetriever
|
| 16 |
+
from langchain_cohere import CohereRerank
|
| 17 |
|
| 18 |
import sage.vector_store as vector_store
|
| 19 |
from sage.llm import build_llm_via_langchain
|
|
|
|
| 24 |
def build_rag_chain(args):
|
| 25 |
"""Builds a RAG chain via LangChain."""
|
| 26 |
llm = build_llm_via_langchain(args.llm_provider, args.llm_model)
|
| 27 |
+
|
| 28 |
retriever = vector_store.build_from_args(args).to_langchain().as_retriever()
|
| 29 |
+
if args.reranker == "cohere":
|
| 30 |
+
compressor = CohereRerank(model="rerank-english-v3.0", cohere_api_key=os.environ.get("COHERE_API_KEY"))
|
| 31 |
+
retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
|
| 32 |
|
| 33 |
# Prompt to contextualize the latest query based on the chat history.
|
| 34 |
contextualize_q_system_prompt = (
|
|
|
|
| 89 |
default="http://localhost:8882",
|
| 90 |
help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
|
| 91 |
)
|
| 92 |
+
parser.add_argument("--reranker", default="cohere", choices=["none", "cohere"])
|
| 93 |
parser.add_argument(
|
| 94 |
"--share",
|
| 95 |
default=False,
|