File size: 3,066 Bytes
b245976
37fd533
9931049
b245976
1863eda
b245976
 
 
 
 
 
 
9ebda3d
cb82e54
4762f17
37fd533
b245976
 
5f60d29
b245976
 
 
9931049
b245976
9931049
 
67eb5ac
37fd533
9931049
71800a6
9931049
 
 
 
b245976
 
9931049
b245976
 
 
 
 
 
 
 
 
 
 
9931049
b245976
 
 
 
 
 
 
9931049
b245976
 
 
 
 
9931049
 
b245976
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71800a6
b245976
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import streamlit as st 
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import pipeline
import torch 

from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import FAISS 
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
#from constants import CHROMA_SETTINGS
from streamlit_chat import message
import safetensors

# Checkpoint (local directory or hub name) of the LaMini-Flan-T5 model used
# for answer generation.
checkpoint = "LaMini-Flan-T5-77M"

# Load tokenizer and seq2seq model once at module import time so Streamlit
# reruns reuse the same weights instead of reloading them.
tokenizer = T5Tokenizer.from_pretrained(checkpoint)
base_model = T5ForConditionalGeneration.from_pretrained(
    checkpoint,
    device_map='auto',          # let transformers pick device placement
    torch_dtype=torch.float32,  # full precision; safe default on CPU
)


@st.cache_resource
def llm_pipeline():
    """Build and cache the local text2text-generation pipeline for LangChain.

    Returns:
        HuggingFacePipeline: a LangChain LLM wrapper around the local
        LaMini-Flan-T5 model and tokenizer loaded at module level.
    """
    pipe = pipeline(
        'text2text-generation',
        model=base_model,
        tokenizer=tokenizer,
        # BUG FIX: temperature only takes effect when sampling is enabled;
        # with the default do_sample=False the original temperature=0.5 was
        # silently ignored (and recent transformers warns about it).
        do_sample=True,
        temperature=0.5,
        # The model's default generation length is very short and truncates
        # answers; allow a reasonably long response.
        max_length=256,
    )
    # Wrap the pipeline so RetrievalQA can drive it as a LangChain LLM.
    return HuggingFacePipeline(pipeline=pipe)

@st.cache_resource
def qa_llm():
    """Build and cache the RetrievalQA chain over the local FAISS index.

    Returns:
        RetrievalQA: chain that retrieves chunks from the "vector_data"
        index and answers with the local LLM, also returning the source
        documents it used.
    """
    llm = llm_pipeline()
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    # allow_dangerous_deserialization is required by recent langchain-community
    # to unpickle a FAISS index from disk; safe here because "vector_data" is
    # produced by this project, not untrusted input. TODO confirm the installed
    # langchain-community version accepts this keyword.
    db = FAISS.load_local(
        "vector_data",
        embeddings,
        allow_dangerous_deserialization=True,
    )
    retriever = db.as_retriever()
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",  # concatenate retrieved chunks into one prompt
        retriever=retriever,
        return_source_documents=True,
    )
    return qa

def process_answer(instruction):
    """Run the retrieval-QA chain on one user query.

    Args:
        instruction: dict of the form ``{'query': <user question>}`` as
            expected by the RetrievalQA chain.

    Returns:
        str: the generated answer text (source documents are discarded).
    """
    # Removed dead code from the original: an unused ``response = ''`` local
    # and a no-op ``instruction = instruction`` self-assignment.
    qa = qa_llm()
    generated = qa(instruction)
    return generated['result']

# Display conversation history using Streamlit messages
def display_conversation(history):
    for i in range(len(history["generated"])):
        message(history["past"][i], is_user=True, key=str(i) + "_user")
        message(history["generated"][i],key=str(i))


def main():
    """Streamlit entry point: render the UI, manage chat state, answer queries."""
    st.title('Chat with Your Data 🦜📄')
    with st.expander("About the Chatbot"):
        st.markdown(
            """
            This is a Generative AI powered Chatbot that interacts with you and you can ask followup questions.
            """
        )

    user_input = st.text_input("Question:", placeholder="Ask about your PDF", key='input')
    with st.form(key='my_form', clear_on_submit=True):
        submit_button = st.form_submit_button(label='Send')

    # Initialize session state for generated responses and past messages.
    if "generated" not in st.session_state:
        st.session_state["generated"] = ["I am ready to help you"]
    if "past" not in st.session_state:
        st.session_state["past"] = ["Hey there!👋"]

    # BUG FIX: the original condition `submit_button and user_input or user_input`
    # parses as `(submit_button and user_input) or user_input`, so any non-empty
    # text box regenerated an answer on EVERY Streamlit rerun, regardless of the
    # Send button. Gate strictly on the submit click.
    if submit_button and user_input:
        st.session_state['past'].append(user_input)
        with st.spinner('Generating response...'):
            answer = process_answer({'query': user_input})
        st.session_state['generated'].append(answer)

    if st.session_state["generated"]:
        display_conversation(st.session_state)


if __name__ == '__main__':
    main()