"""A gradio app that enables users to chat with their codebase.
You must run `sage-index $GITHUB_REPO` first in order to index the codebase into a vector store.
"""
import logging
import configargparse
import gradio as gr
from dotenv import load_dotenv
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.schema import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import sage.config as sage_config
from sage.llm import build_llm_via_langchain
from sage.retriever import build_retriever_from_args
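
# Pull API keys (e.g. OPENAI_API_KEY; the exact variables depend on the configured
# LLM provider) from a local .env file, if one exists.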
load_dotenv()


def build_rag_chain(args):
"""Builds a RAG chain via LangChain."""
llm = build_llm_via_langchain(args.llm_provider, args.llm_model)
retriever = build_retriever_from_args(args)
# Prompt to contextualize the latest query based on the chat history.
contextualize_q_system_prompt = (
"Given a chat history and the latest user question which might reference context in the chat history, "
"formulate a standalone question which can be understood without the chat history. Do NOT answer the question, "
"just reformulate it if needed and otherwise return it as is."
)
contextualize_q_prompt = ChatPromptTemplate.from_messages(
[
("system", contextualize_q_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)
contextualize_q_llm = llm.with_config(tags=["contextualize_q_llm"])
history_aware_retriever = create_history_aware_retriever(contextualize_q_llm, retriever, contextualize_q_prompt)
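    # The "contextualize_q_llm" tag lets _predict below tell query-rewrite tokens apart
    # from answer tokens in the event stream. Note that create_history_aware_retriever
    # only invokes the LLM when chat_history is non-empty; on the first turn, the raw
    # input goes straight to the retriever.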

    qa_system_prompt = (
        f"You are my coding buddy, helping me quickly understand a GitHub repository called {args.repo_id}. "
        "Assume I am an advanced developer and answer my questions in the most succinct way possible."
        "\n\n"
        "Here are some snippets from the codebase."
        "\n\n"
        "{context}"
    )
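    # The "{context}" placeholder is filled in by create_stuff_documents_chain below,
    # which stuffs the retrieved documents into the prompt.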
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", qa_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
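    # The resulting chain expects {"input", "chat_history"} and emits the retrieved
    # documents under "context" and the generated reply under "answer".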
    return rag_chain


def main():
    parser = configargparse.ArgParser(
        description="Chat with a GitHub repository that was previously indexed via `sage-index`.",
        ignore_unknown_config_file_keys=True,
    )
    parser.add(
        "--share",
        default=False,
        # store_true, so that passing any value (e.g. "--share False") doesn't parse as a truthy string.
        action="store_true",
        help="Whether to make the gradio app publicly accessible.",
    )
    validator = sage_config.add_all_args(parser)
    args = parser.parse_args()
    validator(args)

    rag_chain = build_rag_chain(args)

    def source_md(file_path: str, url: str) -> str:
        """Formats a context source in Markdown."""
        return f"[{file_path}]({url})"

    async def _predict(message, history):
        """Performs one RAG operation."""
        history_langchain_format = []
        for human, ai in history:
            history_langchain_format.append(HumanMessage(content=human))
            history_langchain_format.append(AIMessage(content=ai))
        history_langchain_format.append(HumanMessage(content=message))

        query_rewrite = ""
        response = ""
        async for event in rag_chain.astream_events(
            {
                "input": message,
                "chat_history": history_langchain_format,
            },
            version="v1",
        ):
            if event["name"] == "retrieve_documents" and "output" in event["data"]:
                sources = [(doc.metadata["file_path"], doc.metadata["url"]) for doc in event["data"]["output"]]
                # Deduplicate while preserving the order.
                sources = list(dict.fromkeys(sources))
                response += "## Sources:\n" + "\n".join([source_md(s[0], s[1]) for s in sources]) + "\n## Response:\n"
            elif event["event"] == "on_chat_model_stream":
                chunk = event["data"]["chunk"].content
                if "contextualize_q_llm" in event["tags"]:
                    query_rewrite += chunk
                else:
                    # This is the actual response to the user query.
                    if not response:
                        logging.info(f"Query rewrite: {query_rewrite}")
                    response += chunk
                    yield response

    gr.ChatInterface(
        _predict,
        title=args.repo_id,
        examples=["What does this repo do?", "Give me some sample code."],
    ).launch(share=args.share)


if __name__ == "__main__":
    main()
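
# Rough usage sketch (the exact entry point depends on how the package is installed;
# the repo argument and remaining flags come from sage.config.add_all_args):
#   sage-index $GITHUB_REPO
#   python sage/chat.py $GITHUB_REPO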