Aditya757864 committed on
Commit
b245976
·
verified ·
1 Parent(s): e58ba7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -102
app.py CHANGED
@@ -1,125 +1,121 @@
1
- import streamlit as st
2
- from streamlit_chat import message
3
- from langchain.chains import ConversationalRetrievalChain
4
- from langchain.embeddings import HuggingFaceEmbeddings
5
- from langchain.llms import HuggingFacePipeline
6
- from langchain.text_splitter import RecursiveCharacterTextSplitter
7
- from langchain.vectorstores import FAISS
8
- from langchain.memory import ConversationBufferMemory
9
- from langchain_community.document_loaders import PyPDFLoader
10
  from transformers import T5Tokenizer, T5ForConditionalGeneration
11
- import torch
12
  from transformers import pipeline
13
- import os
14
- import tempfile
15
-
 
 
 
 
 
 
 
16
 
17
# Checkpoint shared by the tokenizer and the seq2seq model below.
checkpoint = "LaMini-Flan-T5-783M"
tokenizer = T5Tokenizer.from_pretrained(checkpoint)

# Load the T5 model in float32; device placement is delegated to accelerate.
base_model = T5ForConditionalGeneration.from_pretrained(
    checkpoint,
    device_map='auto',
    torch_dtype=torch.float32,
)
 
 
 
 
 
 
21
 
 
22
def llm_pipeline():
    """Wrap the preloaded T5 model in a text2text-generation pipeline for LangChain.

    Returns:
        HuggingFacePipeline: a LangChain LLM backed by the local model.
    """
    generation_pipe = pipeline(
        'text2text-generation',
        model=base_model,
        tokenizer=tokenizer,
        do_sample=True,       # sample instead of greedy decoding
        temperature=0.5,
        max_length=300,
    )
    return HuggingFacePipeline(pipeline=generation_pipe)
33
 
34
-
35
def initialize_session_state():
    """Seed Streamlit session state with chat history and greeting defaults (idempotent)."""
    defaults = {
        'history': [],
        'generated': ["Hello! Ask me anything about πŸ€—"],
        'past': ["Hey! πŸ‘‹"],
    }
    # Only fill in keys that are missing so reruns keep the accumulated chat.
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
44
-
45
def conversation_chat(query, chain, history):
    """Run one chat turn through *chain*, append (query, answer) to *history*, return the answer.

    Args:
        query: the user's question text.
        chain: callable accepting {"question", "chat_history"} and returning {"answer": ...}.
        history: mutable list of (question, answer) tuples, updated in place.
    """
    response = chain({"question": query, "chat_history": history})
    answer = response["answer"]
    history.append((query, answer))
    return answer
49
-
50
def display_chat_history(chain):
    """Render the chat UI: a submit form plus the accumulated question/answer transcript.

    NOTE(review): nesting reconstructed from a flattened diff — the form and
    submit handling sit in the input container; confirm against the original file.
    """
    reply_container = st.container()
    input_container = st.container()

    with input_container:
        with st.form(key='my_form', clear_on_submit=True):
            user_input = st.text_input("Question:", placeholder="Ask about your PDF", key='input')
            submit_button = st.form_submit_button(label='Send')

        if submit_button and user_input:
            with st.spinner('Generating response...'):
                output = conversation_chat(user_input, chain, st.session_state['history'])
            st.session_state['past'].append(user_input)
            st.session_state['generated'].append(output)

    if st.session_state['generated']:
        with reply_container:
            # 'past' and 'generated' grow in lockstep, so one index serves both.
            for i, reply in enumerate(st.session_state['generated']):
                message(st.session_state["past"][i], is_user=True, key=str(i) + '_user', avatar_style="thumbs")
                message(reply, key=str(i), avatar_style="fun-emoji")
71
-
72
def create_conversational_chain(vector_store):
    """Build a ConversationalRetrievalChain over *vector_store* with buffered chat memory.

    Args:
        vector_store: a vector store exposing as_retriever().
    Returns:
        A ConversationalRetrievalChain using the local LLM pipeline.
    """
    local_llm = llm_pipeline()
    chat_memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    # Retrieve the top-2 chunks and stuff them into a single prompt.
    retriever = vector_store.as_retriever(search_kwargs={"k": 2})
    return ConversationalRetrievalChain.from_llm(
        llm=local_llm,
        chain_type='stuff',
        retriever=retriever,
        memory=chat_memory,
    )
82
-
83
def main():
    """Streamlit entry point: upload PDFs, build a FAISS index over them, and chat.

    Side effects: writes each upload to a temporary file (removed after loading),
    and drives the Streamlit UI.
    """
    # Initialize session state
    initialize_session_state()
    # NOTE(review): title mentions Mistral-7B-Instruct but the loaded model is
    # LaMini-Flan-T5 — confirm intended wording before changing user-facing text.
    st.title("Multi-PDF ChatBot using Mistral-7B-Instruct :books:")
    # Sidebar upload controls
    st.sidebar.title("Document Processing")
    uploaded_files = st.sidebar.file_uploader("Upload files", accept_multiple_files=True)

    if uploaded_files:
        text = []
        for file in uploaded_files:
            file_extension = os.path.splitext(file.name)[1]
            # Persist the upload to disk because PyPDFLoader needs a file path.
            with tempfile.NamedTemporaryFile(delete=False) as temp_file:
                temp_file.write(file.read())
                temp_file_path = temp_file.name

            try:
                loader = None
                if file_extension == ".pdf":
                    loader = PyPDFLoader(temp_file_path)

                if loader:
                    text.extend(loader.load())
            finally:
                # BUGFIX: remove the temp file unconditionally; previously
                # non-PDF uploads leaked one temporary file each because
                # os.remove() only ran when a loader had been created.
                os.remove(temp_file_path)

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=10000, chunk_overlap=20)
        text_chunks = text_splitter.split_documents(text)

        # Create embeddings (CPU-only sentence-transformers model)
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
                                           model_kwargs={'device': 'cpu'})

        # Create vector store
        vector_store = FAISS.from_documents(text_chunks, embedding=embeddings)

        # Create the chain object
        chain = create_conversational_chain(vector_store)

        display_chat_history(chain)

if __name__ == "__main__":
    main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
 
 
 
 
 
 
 
 
2
  from transformers import T5Tokenizer, T5ForConditionalGeneration
 
3
  from transformers import pipeline
4
+ import torch
5
+ import base64
6
+ import textwrap
7
+ from langchain_community.embeddings import SentenceTransformerEmbeddings
8
+ from langchain_community.vectorstores import FAISS
9
+ from langchain.chains import RetrievalQA
10
+ from langchain_community.llms import HuggingFacePipeline
11
+ #from constants import CHROMA_SETTINGS
12
+ from streamlit_chat import message
13
+ import safetensors
14
 
15
# Checkpoint shared by the tokenizer and the seq2seq model below.
checkpoint = "LaMini-Flan-T5-248M"
tokenizer = T5Tokenizer.from_pretrained(checkpoint)

# CPU-pinned float32 load; layers that do not fit are offloaded to disk.
base_model = T5ForConditionalGeneration.from_pretrained(
    checkpoint,
    device_map='cpu',
    torch_dtype=torch.float32,
    offload_folder="offload",
)
24
+
25
 
26
@st.cache_resource
def llm_pipeline():
    """Build (once, via Streamlit's resource cache) a text2text pipeline for LangChain.

    Returns:
        HuggingFacePipeline: a LangChain LLM backed by the local model.
    """
    generation_pipe = pipeline(
        'text2text-generation',
        model=base_model,
        tokenizer=tokenizer,
        max_length=226,
        do_sample=True,    # nucleus sampling rather than greedy decoding
        temperature=0.5,
        top_p=0.95,
    )
    return HuggingFacePipeline(pipeline=generation_pipe)
39
 
40
@st.cache_resource
def qa_llm():
    """Assemble (once, via Streamlit's resource cache) a RetrievalQA chain.

    Loads the persisted FAISS index from "vector_data" and pairs its retriever
    with the local LLM pipeline. Returns the chain; answers include sources.
    """
    local_llm = llm_pipeline()
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    # Embeddings must match the model used when the index was built.
    db = FAISS.load_local("vector_data", embeddings)
    retriever = db.as_retriever()
    return RetrievalQA.from_chain_type(
        llm=local_llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
54
 
55
def process_answer(instruction):
    """Answer *instruction* (a {'query': ...} mapping) via the cached RetrievalQA chain.

    Args:
        instruction: input accepted by the chain, e.g. {'query': question_text}.
    Returns:
        The chain's 'result' string; source documents are discarded.
    """
    # BUGFIX: removed dead code — an unused `response = ''` and the
    # no-op self-assignment `instruction = instruction`.
    qa = qa_llm()
    generated_text = qa(instruction)
    return generated_text['result']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
def display_conversation(history):
    """Render alternating user/bot messages from a history mapping.

    Args:
        history: mapping with parallel "past" (user) and "generated" (bot) lists.
    """
    for idx, reply in enumerate(history["generated"]):
        message(history["past"][idx], is_user=True, key=str(idx) + "_user")
        message(reply, key=str(idx))
68
 
 
 
69
 
70
def main():
    """Streamlit entry point: read a question, query the RetrievalQA chain, render the transcript."""
    st.title('Chat with Your Data πŸ¦œπŸ“„')
    with st.expander("About the Chatbot"):
        st.markdown(
            """
            This is a Generative AI powered Chatbot that interacts with you and you can ask followup questions.
            """
        )

    user_input = st.text_input("", key="input")

    # Initialize session state for generated responses and past messages
    if "generated" not in st.session_state:
        st.session_state["generated"] = ["I am ready to help you"]
    if "past" not in st.session_state:
        st.session_state["past"] = ["Hey there!"]

    # Search the database for a response based on user input and update session state
    if user_input:
        answer = process_answer({'query': user_input})
        st.session_state["past"].append(user_input)
        # BUGFIX: dropped the redundant `response = answer` intermediate.
        st.session_state["generated"].append(answer)

    # Display conversation history using Streamlit messages
    if st.session_state["generated"]:
        display_conversation(st.session_state)
    # BUGFIX: removed a large unused triple-quoted string (`d = """..."""`)
    # holding dead, commented-out form-based UI code.


if __name__ == '__main__':
    main()