File size: 4,931 Bytes
6e0da70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import pathlib
import requests
import subprocess
import streamlit as st
from streamlit_chat import message
import openai
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import PythonCodeTextSplitter
from langchain.llms import OpenAI
from langchain.schema import Document
from langchain.chains.question_answering import load_qa_chain
from langchain.chains import RetrievalQA

# Hide traceback from end users (errors are reported via st.error instead)
st.set_option('client.showErrorDetails', False)

# Setting page title and header
st.set_page_config(page_title="CODE CHAT", page_icon=":robot_face:")
st.markdown("<h1 style='text-align: center; color: red;'>CODE CHAT</h1>", unsafe_allow_html=True)
st.markdown("<h3 style='text-align: center;'>Ask questions about our GitHub repositories</h3>", unsafe_allow_html=True)

# Initialise session state variables
# 'generated' holds model answers, 'past' holds the matching user questions;
# the two lists are appended in lockstep by the form handler below.
if 'generated' not in st.session_state:
    st.session_state['generated'] = []
if 'past' not in st.session_state:
    st.session_state['past'] = []

# Predefine GitHub username and token (replace with your company's GitHub details)
# NOTE(review): hardcoding a token in source is a security risk — prefer
# st.secrets or an environment variable in production.
GITHUB_USERNAME = 'your_company_github_username'
GITHUB_TOKEN = 'your_github_token'
openai_api_key = st.text_input("Enter your OpenAI API Key", type='password')

# Create a button for the user to submit their API key
if st.button('Submit'):
    # Set the OpenAI API key as an environment variable
    # (OpenAIEmbeddings / OpenAI below read OPENAI_API_KEY from the environment)
    os.environ["OPENAI_API_KEY"] = openai_api_key
    openai.api_key = openai_api_key
    
    # Check if the API key is valid by making a simple API call
    try:
        models = openai.Model.list()
        st.success("API key is valid!")
    except Exception as e:
        st.error(f"Error testing API key: {e}")

def get_github_repos(username, token):
    """Return the clone URLs of *username*'s GitHub repositories.

    Args:
        username: GitHub account whose repositories are listed.
        token: personal access token used for authentication.

    Returns:
        list[str]: HTTPS clone URLs, one per repository.

    Raises:
        requests.HTTPError: on a non-2xx response (bad token, rate limit, ...).
    """
    url = f"https://api.github.com/users/{username}/repos"
    headers = {"Authorization": f"token {token}"}
    # timeout prevents the Streamlit worker hanging forever on a stalled request
    response = requests.get(url, headers=headers, timeout=30)
    # Without this, an auth/rate-limit error body (a dict) would crash below
    # with a confusing KeyError on 'clone_url'.
    response.raise_for_status()
    repos = response.json()
    return [repo['clone_url'] for repo in repos]

def clone_repo(repo_url, clone_dir):
    """Clone *repo_url* into *clone_dir* (skipped if already present).

    Args:
        repo_url: HTTPS clone URL, typically ending in ``.git``.
        clone_dir: parent directory that will contain the checkout.

    Returns:
        str: local path of the repository checkout.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` fails.
    """
    repo_name = repo_url.split("/")[-1]
    # Strip only a trailing '.git' — str.replace would also mangle a name
    # that merely contains '.git' somewhere in the middle.
    if repo_name.endswith('.git'):
        repo_name = repo_name[:-4]
    # Ensure the parent directory exists before cloning into it.
    os.makedirs(clone_dir, exist_ok=True)
    repo_path = os.path.join(clone_dir, repo_name)
    if not os.path.exists(repo_path):
        # check=True: a failed clone must raise instead of silently returning
        # a path that does not exist.
        subprocess.run(["git", "clone", repo_url, repo_path], check=True)
    return repo_path

def get_repo_docs(repo_path):
    """Yield one Document per Jupyter notebook (``*.ipynb``) under *repo_path*.

    Each Document's ``metadata['source']`` is the notebook's path relative to
    the repository root.
    """
    repo = pathlib.Path(repo_path)
    for codefile in repo.glob("**/*.ipynb"):
        rel_path = codefile.relative_to(repo)
        # Notebooks are JSON and therefore UTF-8; relying on the platform
        # default encoding (the old open() call) can fail on Windows.
        yield Document(
            page_content=codefile.read_text(encoding="utf-8"),
            metadata={"source": str(rel_path)},
        )

def get_all_repo_docs(clone_dir):
    """Yield Documents from every cloned repository inside *clone_dir*."""
    for entry in os.listdir(clone_dir):
        repo_path = os.path.join(clone_dir, entry)
        # Skip stray files (e.g. .DS_Store) that are not repo checkouts.
        if os.path.isdir(repo_path):
            yield from get_repo_docs(repo_path)

def get_source_chunks_from_repos(clone_dir):
    """Split every repo document into embedding-sized chunks.

    Args:
        clone_dir: directory containing the cloned repositories.

    Returns:
        list[Document]: chunks of at most 1024 characters (30-char overlap),
        each carrying its own copy of the source metadata.
    """
    source_chunks = []
    splitter = PythonCodeTextSplitter(chunk_size=1024, chunk_overlap=30)
    for source in get_all_repo_docs(clone_dir):
        for chunk in splitter.split_text(source.page_content):
            # dict() copy: otherwise every chunk of a source shares one mutable
            # metadata object, so mutating one chunk's metadata corrupts all.
            source_chunks.append(
                Document(page_content=chunk, metadata=dict(source.metadata))
            )
    return source_chunks

def generate_response(input_text):
    """Answer *input_text* with retrieval-augmented QA over the cloned repos.

    Clones/updates all repos for GITHUB_USERNAME, builds (or reopens) a
    persisted Chroma vector store of their notebook chunks, and runs a
    "stuff"-type RetrievalQA chain over it.

    Note: repos are re-fetched and re-cloned on every call; clone_repo skips
    already-present checkouts, but the GitHub API call still happens each time.
    """
    CLONE_DIR = './cloned_repos'
    CHROMA_DB_PATH = f'./chroma/{GITHUB_USERNAME}_repos'

    # Fetch all repos and clone them locally
    repo_urls = get_github_repos(GITHUB_USERNAME, GITHUB_TOKEN)
    for repo_url in repo_urls:
        clone_repo(repo_url, CLONE_DIR)

    vector_db = None

    # Build the index once; subsequent calls reuse the persisted DB on disk.
    # NOTE(review): new commits after the first build are never re-indexed —
    # delete CHROMA_DB_PATH to force a rebuild.
    if not os.path.exists(CHROMA_DB_PATH):
        source_chunks = get_source_chunks_from_repos(CLONE_DIR)
        vector_db = Chroma.from_documents(source_chunks, OpenAIEmbeddings(), persist_directory=CHROMA_DB_PATH) 
        vector_db.persist()
    else:
        vector_db = Chroma(persist_directory=CHROMA_DB_PATH, embedding_function=OpenAIEmbeddings())

    # "stuff" chain: all retrieved chunks are concatenated into one prompt.
    # temperature=1 gives maximally varied answers — lower it for determinism.
    qa_chain = load_qa_chain(OpenAI(temperature=1), chain_type="stuff")
    qa = RetrievalQA(combine_documents_chain=qa_chain, retriever=vector_db.as_retriever())
    query_response = qa.run(input_text)
    return query_response

# Chat layout: answers render above the input form.
response_container = st.container()

input_container = st.container()

with input_container:
    # clear_on_submit resets the text area after each question.
    with st.form(key='my_form', clear_on_submit=True):
        user_input = st.text_area("You:", key='input', height=100)
        submit_button = st.form_submit_button(label='Send')

    if submit_button and user_input:
        try:
            query_response = generate_response(user_input)
            # Append question and answer in lockstep so index i pairs them below.
            st.session_state['past'].append(user_input)
            st.session_state['generated'].append(query_response)
        except Exception as e:
            # Broad catch at the UI boundary: surface the error to the user
            # instead of crashing the Streamlit script run.
            st.error(f"An error occurred: {e}")

# Replay the full conversation history on every rerun.
if st.session_state['generated']:
    with response_container:
        for i in range(len(st.session_state['generated'])):
            message(st.session_state["past"][i], is_user=True, key=str(i) + '_user')
            # Answers are shown as a code block since they concern source code.
            st.code(st.session_state["generated"][i], language="python", line_numbers=False)