"""A gradio app that enables users to chat with their codebase.
You must run main.py first in order to index the codebase into a vector store.
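
Typical invocation (the script filename here is an assumption; use whatever
name this file is saved under):

    python app.py <repo_id> --index_name <index-name> [--share]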
"""
import argparse

import gradio as gr
from dotenv import load_dotenv
from langchain.chains import (create_history_aware_retriever,
                              create_retrieval_chain)
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.schema import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI

import vector_store
from repo_manager import RepoManager

# Load environment variables (e.g. the OpenAI API key) from a local .env file.
load_dotenv()


def build_rag_chain(args):
    """Builds a RAG chain via LangChain."""
    llm = ChatOpenAI(model=args.openai_model)
    retriever = vector_store.build_from_args(args).to_langchain().as_retriever()

    # Prompt to contextualize the latest query based on the chat history.
    contextualize_q_system_prompt = (
        "Given a chat history and the latest user question which might reference context in the chat "
        "history, formulate a standalone question which can be understood without the chat history. "
        "Do NOT answer the question, just reformulate it if needed and otherwise return it as is."
    )
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)

    # Prompt to answer the (contextualized) question from the retrieved code snippets.
    qa_system_prompt = (
        f"You are my coding buddy, helping me quickly understand a GitHub repository called {args.repo_id}. "
        "Assume I am an advanced developer and answer my questions in the most succinct way possible."
        "\n\n"
        "Here are some snippets from the codebase:"
        "\n\n"
        "{context}"
    )
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", qa_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
    return rag_chain


def append_sources_to_response(response, repo_id):
    """Given a RAG chain response, appends to its answer GitHub links of the context sources."""
    filenames = [document.metadata["filename"] for document in response["context"]]
    # Deduplicate filenames while preserving their order.
    filenames = list(dict.fromkeys(filenames))
    repo_manager = RepoManager(repo_id)
    github_links = [repo_manager.github_link_for_file(filename) for filename in filenames]
    return response["answer"] + "\n\nSources:\n" + "\n".join(github_links)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="UI to chat with your codebase")
    parser.add_argument("repo_id", help="The ID of the repository to index")
    parser.add_argument(
        "--openai_model",
        default="gpt-4",
        help="The OpenAI model to use for response generation",
    )
    parser.add_argument("--vector_store_type", default="pinecone", choices=["pinecone", "marqo"])
    parser.add_argument("--index_name", required=True, help="Vector store index name")
    parser.add_argument(
        "--marqo_url",
        default="http://localhost:8882",
        help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Whether to make the gradio app publicly accessible.",
    )
    args = parser.parse_args()
    rag_chain = build_rag_chain(args)

    def _predict(message, history):
        """Performs one RAG operation."""
        # gradio's ChatInterface provides history as (user message, AI message) tuples;
        # convert them into the message objects LangChain expects. The current message
        # is passed separately as "input", so it is not appended to the history.
        history_langchain_format = []
        for human, ai in history:
            history_langchain_format.append(HumanMessage(content=human))
            history_langchain_format.append(AIMessage(content=ai))
        response = rag_chain.invoke({"input": message, "chat_history": history_langchain_format})
        answer = append_sources_to_response(response, args.repo_id)
        return answer

    gr.ChatInterface(
        _predict,
        title=args.repo_id,
        description=f"Code sage for your repo: {args.repo_id}",
        examples=["What does this repo do?", "Give me some sample code."],
    ).launch(share=args.share)