File size: 4,567 Bytes
559dd34
 
 
 
 
 
 
 
57007fe
 
 
559dd34
 
 
5f9eeb4
559dd34
5f9eeb4
559dd34
 
 
 
 
 
 
 
5f9eeb4
559dd34
 
 
 
 
 
 
 
 
 
 
 
 
 
57007fe
559dd34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57007fe
559dd34
 
 
 
 
 
 
d5c979a
 
 
559dd34
57007fe
 
559dd34
57007fe
 
 
559dd34
d5c979a
 
 
 
 
559dd34
 
 
 
 
 
 
 
 
 
 
57007fe
559dd34
 
 
d5c979a
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""A gradio app that enables users to chat with their codebase.

You must run main.py first in order to index the codebase into a vector store.
"""

import argparse

import gradio as gr
from dotenv import load_dotenv
from langchain.chains import (create_history_aware_retriever,
                              create_retrieval_chain)
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.schema import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI

import vector_store
from repo_manager import RepoManager

load_dotenv()


def build_rag_chain(args):
    """Builds a history-aware RAG chain via LangChain.

    Args:
        args: Parsed CLI namespace. Must provide ``openai_model``, ``repo_id``,
            and whatever vector-store options ``vector_store.build_from_args``
            consumes (``vector_store_type``, ``index_name``, ``marqo_url``).

    Returns:
        A LangChain retrieval chain that accepts ``{"input": ..., "chat_history": [...]}``
        and returns a dict with at least ``"answer"`` and ``"context"`` keys.
    """
    llm = ChatOpenAI(model=args.openai_model)
    retriever = vector_store.build_from_args(args).to_langchain().as_retriever()

    # Prompt to contextualize the latest query based on the chat history, so the
    # retriever gets a standalone question instead of one full of pronouns.
    contextualize_q_system_prompt = (
        "Given a chat history and the latest user question which might reference context in the chat history, "
        # Fixed typo: "formualte" -> "formulate".
        "formulate a standalone question which can be understood without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)

    # Note the trailing spaces: these fragments are concatenated, so without them
    # sentences would run together ("...called repo.Assume I am...").
    qa_system_prompt = (
        f"You are my coding buddy, helping me quickly understand a GitHub repository called {args.repo_id}. "
        "Assume I am an advanced developer and answer my questions in the most succinct way possible."
        "\n\n"
        "Here are some snippets from the codebase."
        "\n\n"
        "{context}"
    )
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", qa_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )

    # Stuff retrieved documents into the QA prompt, then wire retrieval + QA together.
    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
    return rag_chain


def append_sources_to_response(response, repo_id=None):
    """Given a RAG chain response, appends GitHub links of the context sources.

    Args:
        response: Chain output dict with a ``"context"`` list of documents
            (each carrying ``metadata["filename"]``) and an ``"answer"`` string.
        repo_id: Repository ID used to build GitHub links. Defaults to the
            module-level CLI ``args.repo_id`` for backward compatibility with
            the original global-reading behavior; pass it explicitly when
            using this function outside the script entry point.

    Returns:
        The answer string followed by a deduplicated "Sources:" section.
    """
    if repo_id is None:
        # Fall back to the global parsed under `if __name__ == "__main__"`.
        # NOTE(review): this only exists when run as a script — callers
        # importing this module should pass repo_id explicitly.
        repo_id = args.repo_id
    filenames = [document.metadata["filename"] for document in response["context"]]
    # Deduplicate filenames while preserving their order.
    filenames = list(dict.fromkeys(filenames))
    repo_manager = RepoManager(repo_id)
    github_links = [repo_manager.github_link_for_file(filename) for filename in filenames]
    return response["answer"] + "\n\nSources:\n" + "\n".join(github_links)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="UI to chat with your codebase")
    parser.add_argument("repo_id", help="The ID of the repository to index")
    parser.add_argument(
        "--openai_model",
        default="gpt-4",
        help="The OpenAI model to use for response generation",
    )
    parser.add_argument("--vector_store_type", default="pinecone", choices=["pinecone", "marqo"])
    parser.add_argument("--index_name", required=True, help="Vector store index name")
    parser.add_argument(
        "--marqo_url",
        default="http://localhost:8882",
        help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
    )
    parser.add_argument(
        "--share",
        # Fixed: the original used `default=False` with a value-taking argument,
        # so ANY supplied value — including `--share False` — parsed as a truthy
        # string. `store_true` makes `--share` a proper boolean flag.
        action="store_true",
        help="Whether to make the gradio app publicly accessible.",
    )
    args = parser.parse_args()

    rag_chain = build_rag_chain(args)

    def _predict(message, history):
        """Performs one RAG operation for the gradio chat callback."""
        # Convert gradio's [(human, ai), ...] history into LangChain messages.
        history_langchain_format = []
        for human, ai in history:
            history_langchain_format.append(HumanMessage(content=human))
            history_langchain_format.append(AIMessage(content=ai))
        history_langchain_format.append(HumanMessage(content=message))
        response = rag_chain.invoke({"input": message, "chat_history": history_langchain_format})
        answer = append_sources_to_response(response)
        return answer

    gr.ChatInterface(
        _predict,
        title=args.repo_id,
        description=f"Code sage for your repo: {args.repo_id}",
        examples=["What does this repo do?", "Give me some sample code."],
    ).launch(share=args.share)