QABot / app.py
jyotiguptahk's picture
second
aa1bbfb
import os
from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import TextLoader, DirectoryLoader
from langchain.embeddings import CohereEmbeddings
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms import OpenAI
from langchain.llms import Cohere
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
import streamlit as st
def retrieve(query, llm, retriever):
    """Answer *query* with a "stuff" RetrievalQA chain over *retriever*.

    A new chain is assembled on every call from the given language model and
    retriever, using a custom prompt that tells the model to answer only from
    the retrieved context and to admit when it does not know the answer.

    Args:
        query: The question passed to the chain's ``run`` method.
        llm: Language model object handed to ``RetrievalQA.from_chain_type``.
        retriever: Retriever object handed to ``RetrievalQA.from_chain_type``.

    Returns:
        Whatever ``run(query)`` on the chain returns (the caller displays it
        as the answer text).
    """
    # Prompt with the two placeholders the chain fills in: the retrieved
    # context and the user's question.
    prompt_text = """
Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.Use only the document for your answer and you may summarize the answer in 50 words to make it look better.
{context}
Question: {question}
"""
    qa_prompt = PromptTemplate(
        template=prompt_text,
        input_variables=["context", "question"],
    )

    # "stuff" chain type: built with the custom prompt passed via
    # chain_type_kwargs.
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs={"prompt": qa_prompt},
    )
    answer = chain.run(query)
    return answer
def _build_models(provider, api_key):
    """Create the LLM and the embedding model for the selected provider.

    The API key is exported as ``OPENAI_API_KEY`` or ``COHERE_API_KEY``
    before the model objects are constructed.

    Args:
        provider: ``'OpenAI'`` or ``'Cohere'`` (the sidebar radio choices).
        api_key: The key typed into the sidebar form.

    Returns:
        An ``(llm, embeddings)`` tuple.
    """
    if provider == 'OpenAI':
        os.environ["OPENAI_API_KEY"] = api_key
        return OpenAI(temperature=0.7, verbose=True), OpenAIEmbeddings()
    # The sidebar radio only offers 'OpenAI' and 'Cohere'.
    os.environ["COHERE_API_KEY"] = api_key
    return Cohere(temperature=0.7, verbose=True), CohereEmbeddings()


def _ingest(file_name, file_contents, embeddings):
    """Save the upload, index it in a persisted Chroma store, return a retriever.

    Args:
        file_name: Bare file name (no directory part) to write the upload to.
        file_contents: Raw bytes of the uploaded file.
        embeddings: Embedding model used to build the vector store.

    Returns:
        A retriever over the new store that fetches the top 3 chunks.
    """
    # TextLoader reads from a path, so the upload has to be written to disk.
    with open(file_name, "wb") as f:
        f.write(file_contents)
    print(file_name)

    documents = TextLoader(file_name, autodetect_encoding=True).load()
    # Split the text into chunks before embedding.
    text_splitter = CharacterTextSplitter(chunk_size=1000)
    docs = text_splitter.split_documents(documents)

    # Name the store after the file without its extension. splitext works for
    # any extension length; the "_db" suffix keeps the directory distinct from
    # the saved file when the name has no extension to strip.
    stem, ext = os.path.splitext(file_name)
    persist_directory = stem if ext else stem + "_db"

    vectordb = Chroma.from_documents(documents=docs,
                                     embedding=embeddings,
                                     persist_directory=persist_directory)
    # Persist the db to disk.
    vectordb.persist()
    return vectordb.as_retriever(search_kwargs={"k": 3})


def main():
    """Render the Streamlit Q&A page.

    Flow: pick a provider and enter an API key in the sidebar, upload a
    ``.txt`` file, then ask questions that are answered by ``retrieve`` over
    the uploaded document. Without an API key an info message is shown and
    the script run is stopped.

    The uploaded file is ingested once per (provider, key, file name, file
    contents) combination and the resulting retriever is kept in
    ``st.session_state``; Streamlit re-executes this function on every widget
    interaction, and re-ingesting each time would re-embed the document and
    add the same chunks to the persisted store again.
    """
    # Main title of the application
    st.title("Q&A BOT")
    if 'counter' not in st.session_state:
        st.session_state['counter'] = 0

    with st.sidebar:
        with st.form('Cohere/OpenAI'):
            mod = st.radio('Choose OpenAI/Cohere', ('OpenAI', 'Cohere'))
            api_key = st.text_input('Enter API key', type="password")
            st.form_submit_button("Submit")

    if not api_key:
        st.info("Please add configuration details in left panel")
        st.stop()
        # Defensive: never fall through without a key, even if st.stop()
        # does not interrupt execution (e.g. when run outside `streamlit run`).
        return

    llm, embeddings = _build_models(mod, api_key)

    uploaded_file = st.file_uploader("Upload a file to ingest", type=["txt"])
    retriever = None
    if uploaded_file is not None:
        # The file name comes from the client: keep only its final component
        # so the upload cannot be written outside the working directory.
        safe_name = os.path.basename(uploaded_file.name) or "upload.txt"
        # getvalue() returns the full buffer regardless of the read position.
        file_contents = uploaded_file.getvalue()

        ingest_key = (mod, hash(api_key), safe_name, hash(file_contents))
        if ('qa_ingest_key' not in st.session_state
                or st.session_state['qa_ingest_key'] != ingest_key):
            st.session_state['qa_retriever'] = _ingest(
                safe_name, file_contents, embeddings)
            st.session_state['qa_ingest_key'] = ingest_key
            st.session_state['counter'] += 1
        retriever = st.session_state['qa_retriever']

    # The query box appears once at least one file has been ingested.
    if st.session_state['counter'] > 0:
        query = st.text_input("Query: ", "", key="input")
        result_display = st.empty()
        if query:
            if retriever is None:
                # The counter survives reruns, but the upload may have been
                # removed since; there is nothing to search in that case.
                st.info("Please upload a file to ingest before asking a question")
            else:
                result = retrieve(query, llm, retriever)
                # Text area for editing the generated response
                result_display.text_area("Result:", value=result, height=500)
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()