File size: 4,038 Bytes
384e185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df9de08
 
 
384e185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64e59d0
aa1bbfb
 
384e185
64e59d0
aa1bbfb
 
64e59d0
 
 
 
 
 
 
df9de08
 
 
 
64e59d0
df9de08
 
 
 
 
7429bf2
384e185
df9de08
384e185
df9de08
384e185
 
 
7429bf2
84644d5
5853234
7429bf2
5853234
 
 
384e185
 
 
aa1bbfb
384e185
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111

import os
from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import TextLoader, DirectoryLoader
from langchain.embeddings import CohereEmbeddings
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms import OpenAI
from langchain.llms import Cohere
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

import streamlit as st

def retrieve(query, llm, retriever):
    """Answer ``query`` with a "stuff" RetrievalQA chain and return its output.

    Args:
        query: The user's question, passed to the chain's ``run`` method.
        llm: The language model handed to ``RetrievalQA.from_chain_type``.
        retriever: The retriever handed to ``RetrievalQA.from_chain_type``.

    Returns:
        Whatever ``qa_chain.run(query)`` returns.
    """
    # Prompt text is sent to the model verbatim; {context} and {question}
    # are the two template variables declared below.
    prompt_text = """
    Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.Use only the document for your answer and you may summarize the answer in 50 words to make it look better.

    {context}

    Question: {question}
    """
    qa_prompt = PromptTemplate(
        template=prompt_text,
        input_variables=["context", "question"],
    )

    # Assemble the question-answering chain around the supplied LLM/retriever.
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs={"prompt": qa_prompt},
    )

    answer = chain.run(query)
    return answer


def _ingest_upload(file_name, file_contents, embeddings):
    """Save an uploaded text file, index it in Chroma and return a retriever.

    Args:
        file_name: Name reported by the upload widget. Only its final path
            component is used for the on-disk copy.
        file_contents: Raw bytes of the uploaded file.
        embeddings: Embedding object passed to ``Chroma.from_documents``.

    Returns:
        A retriever over the new vector store, configured with ``k=3``.
    """
    # Strip any directory part so a crafted name cannot write outside the
    # working directory.
    save_path = os.path.basename(file_name) or "uploaded.txt"

    # TextLoader reads from a path, so the upload is written to disk first.
    with open(save_path, "wb") as f:
        f.write(file_contents)

    loader = TextLoader(save_path, autodetect_encoding=True)
    documents = loader.load()
    # Split the text into chunks before embedding.
    text_splitter = CharacterTextSplitter(chunk_size=1000)
    docs = text_splitter.split_documents(documents)

    # Persist next to the saved file, named after it without its extension.
    # If there is no extension, add a suffix so the directory cannot collide
    # with the file itself.
    root, ext = os.path.splitext(save_path)
    persist_directory = root if ext else save_path + "_db"

    vectordb = Chroma.from_documents(documents=docs,
                                     embedding=embeddings,
                                     persist_directory=persist_directory)
    # Persist the db to disk.
    # NOTE(review): re-ingesting a file with the same name reuses this
    # directory; confirm whether earlier chunks should be cleared first.
    vectordb.persist()
    return vectordb.as_retriever(search_kwargs={"k": 3})


def main():
    """Run the Streamlit Q&A bot page.

    Flow: read provider and API key from the sidebar form, build the LLM and
    embedding objects, ingest an uploaded ``.txt`` file into a Chroma store,
    then answer queries against it via ``retrieve``.

    The retriever is kept in ``st.session_state`` and rebuilt only when the
    provider, API key, file name or file content changes. Streamlit re-executes
    this function on every interaction, so without the cache each query would
    re-embed the whole document and add it to the persisted store again.
    """
    import hashlib  # local import: only needed for the upload fingerprint

    # Main title of the application
    st.title("Q&A BOT")

    if 'counter' not in st.session_state:
        st.session_state['counter'] = 0

    with st.sidebar:
        with st.form('Cohere/OpenAI'):
            mod = st.radio('Choose OpenAI/Cohere', ('OpenAI', 'Cohere'))
            api_key = st.text_input('Enter API key', type="password")
            st.form_submit_button("Submit")

    if not api_key:
        st.info("Please add configuration details in left panel")
        st.stop()
        return  # st.stop() ends the run; the return keeps the guard explicit

    # NOTE(review): os.environ is process-wide, so the key set here is shared
    # by every session served by this process - confirm this is acceptable.
    if mod == 'OpenAI':
        os.environ["OPENAI_API_KEY"] = api_key
        llm = OpenAI(temperature=0.7, verbose=True)
        embeddings = OpenAIEmbeddings()
    else:
        os.environ["COHERE_API_KEY"] = api_key
        llm = Cohere(temperature=0.7, verbose=True)
        embeddings = CohereEmbeddings()

    uploaded_file = st.file_uploader("Upload a file to ingest", type=["txt"])

    if uploaded_file is not None:
        # getvalue() returns the full content regardless of stream position.
        file_contents = uploaded_file.getvalue()

        # Fingerprint everything that affects the index, so ingestion runs
        # once per distinct upload/provider/key rather than on every rerun.
        digest = hashlib.sha256()
        for part in (mod, api_key, uploaded_file.name):
            digest.update(part.encode("utf-8"))
            digest.update(b"\0")
        digest.update(file_contents)
        fingerprint = digest.hexdigest()

        if st.session_state.get('retriever_fingerprint') != fingerprint:
            st.session_state['retriever'] = _ingest_upload(
                uploaded_file.name, file_contents, embeddings)
            st.session_state['retriever_fingerprint'] = fingerprint
            # Counts completed ingestions in this session.
            st.session_state['counter'] += 1

    # Use the cached retriever so queries still work after the upload widget
    # is cleared.
    retriever = st.session_state.get('retriever')
    if retriever is not None:
        query = st.text_input("Query: ", "", key="input")
        result_display = st.empty()

        if query:
            result = retrieve(query, llm, retriever)
            # Text area showing the generated response
            result_display.text_area("Result:", value=result, height=500)

# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()