# chat-bot / app.py
# wesam0099's picture
# Upload 2 files
# 6e0da70 verified
import os
import pathlib
import requests
import subprocess
import streamlit as st
from streamlit_chat import message
import openai
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import PythonCodeTextSplitter
from langchain.llms import OpenAI
from langchain.schema import Document
from langchain.chains.question_answering import load_qa_chain
from langchain.chains import RetrievalQA
# Hide traceback details in the rendered page.
st.set_option('client.showErrorDetails', False)

# Page title, icon and the two centred headings.
st.set_page_config(page_title="CODE CHAT", page_icon=":robot_face:")
st.markdown(
    "<h1 style='text-align: center; color: red;'>CODE CHAT</h1>",
    unsafe_allow_html=True,
)
st.markdown(
    "<h3 style='text-align: center;'>Ask questions about our GitHub repositories</h3>",
    unsafe_allow_html=True,
)

# Conversation history is kept in session state: 'past' receives the user's
# questions and 'generated' the matching answers. Each list is created only
# if it is not already present.
for _history_key in ('generated', 'past'):
    if _history_key not in st.session_state:
        st.session_state[_history_key] = []
# GitHub account whose repositories are indexed, and the token used to list
# them. The values are read from the environment so real credentials do not
# have to be written into this file; the literals are placeholder fallbacks.
GITHUB_USERNAME = os.environ.get('GITHUB_USERNAME', 'your_company_github_username')
GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN', 'your_github_token')

# Surrounding whitespace from copy/paste is not part of the key.
openai_api_key = st.text_input("Enter your OpenAI API Key", type='password').strip()

# Create a button for the user to submit their API key
if st.button('Submit'):
    if not openai_api_key:
        # Do not overwrite an already configured key with an empty string.
        st.warning("Please enter an OpenAI API key.")
    else:
        # Set the OpenAI API key as an environment variable and on the client.
        # NOTE(review): both assignments are process-wide, so the key is
        # presumably shared by every session served by this process - confirm
        # that this is acceptable for the deployment.
        os.environ["OPENAI_API_KEY"] = openai_api_key
        openai.api_key = openai_api_key
        # Check if the API key is valid by making a simple API call
        try:
            openai.Model.list()
        except Exception as e:
            st.error(f"Error testing API key: {e}")
        else:
            st.success("API key is valid!")
def get_github_repos(username, token):
    """Return the clone URLs of the repositories listed for *username*.

    The GitHub REST endpoint ``/users/{username}/repos`` is paginated, so the
    pages are requested one after another until a short or empty page is
    returned.

    Args:
        username: GitHub account name whose repositories are listed.
        token: Token sent in the ``Authorization`` header.

    Returns:
        A list with the ``clone_url`` of every repository returned by the API.

    Raises:
        requests.HTTPError: If GitHub answers with an error status (for
            example bad credentials or an unknown user).
        requests.RequestException: On connection problems or timeouts.
    """
    url = f"https://api.github.com/users/{username}/repos"
    headers = {"Authorization": f"token {token}"}
    page_size = 100  # largest page size the API accepts
    clone_urls = []
    page = 1
    while True:
        response = requests.get(
            url,
            headers=headers,
            params={"per_page": page_size, "page": page},
            timeout=30,  # never hang the UI on a stalled connection
        )
        # Fail loudly instead of iterating over an error payload (a dict).
        response.raise_for_status()
        repos = response.json()
        clone_urls.extend(repo['clone_url'] for repo in repos)
        if len(repos) < page_size:
            break
        page += 1
    return clone_urls
def clone_repo(repo_url, clone_dir):
    """Clone *repo_url* into *clone_dir* unless it is already there.

    The target directory name is the last path segment of the URL with a
    trailing ``.git`` removed. Only the suffix is stripped, so names that
    contain ``.git`` elsewhere (e.g. ``user.github.io``) are kept intact.

    Args:
        repo_url: URL passed to ``git clone``.
        clone_dir: Directory under which the repository directory is created.

    Returns:
        The path of the repository directory inside *clone_dir*. The path is
        returned even if the clone itself failed.
    """
    repo_name = repo_url.rstrip("/").split("/")[-1]
    if repo_name.endswith(".git"):
        repo_name = repo_name[:-len(".git")]
    repo_path = os.path.join(clone_dir, repo_name)
    if not os.path.exists(repo_path):
        # check=False keeps cloning best-effort: one repository that cannot be
        # cloned does not raise here.
        subprocess.run(["git", "clone", repo_url, repo_path], check=False)
    return repo_path
def get_repo_docs(repo_path):
    """Yield one Document per ``*.ipynb`` file found under *repo_path*.

    Args:
        repo_path: Root directory of a cloned repository.

    Yields:
        Document whose ``page_content`` is the raw text of the notebook file
        and whose ``metadata["source"]`` is the file path relative to
        *repo_path*.
    """
    repo = pathlib.Path(repo_path)
    for codefile in repo.glob("**/*.ipynb"):
        # The pattern can also match directories; only regular files are read.
        if not codefile.is_file():
            continue
        # Read with an explicit encoding instead of the platform default, and
        # finish reading before yielding so no file handle stays open while
        # the consumer holds the generator. Undecodable bytes are replaced
        # rather than aborting the whole scan.
        content = codefile.read_text(encoding="utf-8", errors="replace")
        rel_path = codefile.relative_to(repo)
        yield Document(page_content=content, metadata={"source": str(rel_path)})
def get_all_repo_docs(clone_dir):
    """Yield the documents of every repository directory inside *clone_dir*.

    Args:
        clone_dir: Directory that contains one sub-directory per cloned
            repository.

    Yields:
        Whatever ``get_repo_docs`` yields for each sub-directory, visiting the
        sub-directories in sorted name order.
    """
    # Nothing has been cloned yet: yield nothing rather than raising.
    if not os.path.isdir(clone_dir):
        return
    for entry in sorted(os.listdir(clone_dir)):
        repo_path = os.path.join(clone_dir, entry)
        # Stray files next to the repositories are not repositories.
        if os.path.isdir(repo_path):
            yield from get_repo_docs(repo_path)
def get_source_chunks_from_repos(clone_dir):
    """Split every document found under *clone_dir* into chunk Documents.

    Each document produced by ``get_all_repo_docs`` is cut by a
    ``PythonCodeTextSplitter`` (chunk size 1024, overlap 30). Every resulting
    piece becomes its own Document that carries the metadata of the document
    it was cut from.

    Returns:
        A list of chunk Documents, in the order the documents were produced.
    """
    code_splitter = PythonCodeTextSplitter(chunk_size=1024, chunk_overlap=30)
    return [
        Document(page_content=piece, metadata=doc.metadata)
        for doc in get_all_repo_docs(clone_dir)
        for piece in code_splitter.split_text(doc.page_content)
    ]
def generate_response(input_text):
    """Answer *input_text* with a retrieval QA chain over the indexed repositories.

    A Chroma index persisted under ``./chroma/<GITHUB_USERNAME>_repos`` is
    reused when it exists. Otherwise the account's repositories are listed,
    cloned into ``./cloned_repos``, split into chunks and embedded into a new
    persisted index.

    Args:
        input_text: The user's question.

    Returns:
        The answer produced by running the ``RetrievalQA`` chain.

    Raises:
        ValueError: If the index has to be built but no documents were found.
    """
    CLONE_DIR = './cloned_repos'
    CHROMA_DB_PATH = f'./chroma/{GITHUB_USERNAME}_repos'

    if os.path.exists(CHROMA_DB_PATH):
        # The persisted index is reused as-is, so listing and cloning the
        # repositories again would only cost a network round trip per
        # question without changing what is retrieved.
        vector_db = Chroma(persist_directory=CHROMA_DB_PATH, embedding_function=OpenAIEmbeddings())
    else:
        # Fetch all repos and clone them locally
        for repo_url in get_github_repos(GITHUB_USERNAME, GITHUB_TOKEN):
            clone_repo(repo_url, CLONE_DIR)
        source_chunks = get_source_chunks_from_repos(CLONE_DIR)
        if not source_chunks:
            # Stop before an empty index directory is created and then
            # silently reused on every later question.
            raise ValueError(
                "No source documents were found in the cloned repositories; nothing to index."
            )
        vector_db = Chroma.from_documents(source_chunks, OpenAIEmbeddings(), persist_directory=CHROMA_DB_PATH)
        vector_db.persist()

    # NOTE(review): temperature=1 is kept from the original; a lower value may
    # suit answers about code better - confirm the intended setting.
    qa_chain = load_qa_chain(OpenAI(temperature=1), chain_type="stuff")
    qa = RetrievalQA(combine_documents_chain=qa_chain, retriever=vector_db.as_retriever())
    return qa.run(input_text)
# The conversation container is created before the input container.
response_container = st.container()
input_container = st.container()

with input_container:
    # The form clears its text area after each submission.
    with st.form(key='my_form', clear_on_submit=True):
        user_input = st.text_area("You:", key='input', height=100)
        submit_button = st.form_submit_button(label='Send')

    if submit_button and user_input:
        try:
            query_response = generate_response(user_input)
            # Record the exchange only once an answer has been produced.
            st.session_state['past'].append(user_input)
            st.session_state['generated'].append(query_response)
        except Exception as exc:
            st.error(f"An error occurred: {exc}")

if st.session_state['generated']:
    with response_container:
        # Replay the stored history: each question as a chat bubble followed
        # by its answer rendered as a code block.
        for turn, answer in enumerate(st.session_state['generated']):
            question = st.session_state["past"][turn]
            message(question, is_user=True, key=f"{turn}_user")
            st.code(answer, language="python", line_numbers=False)