userlele committed on
Commit
58da6d2
·
verified ·
1 Parent(s): 26af1d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -33
app.py CHANGED
@@ -1,46 +1,171 @@
 
 
 
 
 
1
  import streamlit as st
2
- import tempfile
3
  import os
4
- from llm import load_and_process_pdf, create_vectorstore, create_rag_chain, get_response
5
 
6
- st.set_page_config(page_title="PDF Q&A Chatbot", page_icon="📚")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- st.title("PDF Q&A Chatbot")
9
 
10
- # Initialize session state for vector store and chain
11
- if 'vectorstore' not in st.session_state:
12
- st.session_state.vectorstore = None
13
- if 'rag_chain' not in st.session_state:
14
- st.session_state.rag_chain = None
15
 
16
- # File uploader
17
- uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- if uploaded_file is not None and st.session_state.vectorstore is None:
20
- # Save the uploaded file temporarily
21
- with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
22
- tmp_file.write(uploaded_file.getvalue())
23
- tmp_file_path = tmp_file.name
24
 
25
- # Process the PDF only once
26
- with st.spinner("Processing PDF..."):
27
- splits = load_and_process_pdf(tmp_file_path)
28
- st.session_state.vectorstore = create_vectorstore(splits)
29
- st.session_state.rag_chain = create_rag_chain()
30
 
31
- st.success("PDF processed successfully! Now you can ask questions.")
32
 
33
- # Clean up the temporary file
34
- os.unlink(tmp_file_path)
 
 
 
 
 
 
 
 
 
 
35
 
36
- # Question input
37
- if st.session_state.vectorstore is not None:
38
- question = st.text_input("Ask a question about the PDF:")
 
39
 
40
- if question:
41
- with st.spinner("Generating answer..."):
42
- answer = get_response(st.session_state.rag_chain, st.session_state.vectorstore, question)
43
- st.write("Answer:", answer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- else:
46
- st.info("Please upload a PDF file to get started.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard library
import os
import subprocess
import time

# Third-party
import pandas as pd
import streamlit as st
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# LangChain
from langchain.chains import ConversationalRetrievalChain, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate
from langchain.schema import Document
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import PyPDFLoader, UnstructuredPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFacePipeline
28
+ # Get TOKEN from environment variable
29
 
30
def process_pdf(file_path: str = r"chunk_metadata_template.xlsx"):
    """Load pre-chunked content from an Excel sheet and build a Chroma vector store.

    Each spreadsheet row becomes one LangChain ``Document`` whose metadata
    carries the chunk id, document title, topic and access level.

    Args:
        file_path: Path to the Excel workbook. Must contain the columns
            ``page_content``, ``chunk_id``, ``document_title``, ``topic``
            and ``access``.

    Returns:
        A ``Chroma`` vector store built from the documents using
        BAAI/bge-base-en embeddings on CPU.

    Raises:
        FileNotFoundError: If *file_path* does not exist.
        KeyError: If a required column is missing from the sheet.
    """
    df = pd.read_excel(file_path)

    # One Document per row; the extra metadata supports filtered retrieval.
    chunks = [
        Document(
            page_content=row['page_content'],
            metadata={
                'chunk_id': row['chunk_id'],
                'document_title': row['document_title'],
                'topic': row['topic'],
                'access': row['access'],
            },
        )
        for _, row in df.iterrows()
    ]

    # Normalized embeddings make inner-product search equivalent to
    # cosine similarity. (Original code shadowed the model-name string
    # with the embeddings object under a single name; split here.)
    embedding_model = HuggingFaceEmbeddings(
        model_name="BAAI/bge-base-en",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True},
    )

    return Chroma.from_documents(chunks, embedding_model)
 
 
 
58
 
 
59
 
60
def main():
    """Run the MBAL Insurance Assistant Streamlit app.

    Builds the vector store, loads google/gemma-2b as the LLM, wires a
    retrieval-augmented generation chain, and serves a simple Q&A UI with
    chat history and a feedback widget.
    """
    # Authenticate with the Hugging Face Hub; google/gemma-2b is a gated model.
    # NOTE(review): the token is read from env var 'gemma2' — confirm the name.
    token = os.environ.get('gemma2')
    if token:  # original passed None into the arg list -> TypeError when unset
        subprocess.run(
            ["huggingface-cli", "login", "--token", token, "--add-to-git-credential"]
        )

    st.set_page_config(page_title="MBAL Chatbot", page_icon="🤖", layout="wide")

    # Initialize session state before any widgets render.
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    if "vector_store" not in st.session_state:
        st.session_state.vector_store = None

    st.title("🤖 MBAL Insurance Assistant")

    st.session_state.vector_store = process_pdf()

    if st.session_state.vector_store:
        # CPU inference; float32 avoids half-precision issues on CPU.
        model = AutoModelForCausalLM.from_pretrained(
            "google/gemma-2b",
            low_cpu_mem_usage=True,
            torch_dtype=torch.float32,
        )
        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")

        model_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=256,
            pad_token_id=tokenizer.eos_token_id,
            device_map="auto",
        )
        llm = HuggingFacePipeline(pipeline=model_pipeline)

        # create_retrieval_chain feeds the retrieved documents as {context}
        # and the user's query as {input} — the original template used
        # {question}, which would raise a missing-variable error at
        # prompt-format time.
        template1 = """
        Bạn là một AI trợ lý chuyên cung cấp thông tin cho khách hàng về sản phẩm bảo hiểm của công ty MB Ageas Life tại Việt Nam.
        Hãy trả lời chuyên nghiệp, chính xác, cung cấp thông tin bao quát trước, các trường hợp có thể xảy ra làm ví dụ rồi mới đặt câu hỏi gợi mở nếu chưa rõ. Tất cả các thông tin cung cấp đều trong phạm vi MBAL. Những có đủ thông tin khách hàng thì mời khách hàng đăng ký để nhận tư vấn trên https://www.mbageas.life/
        {context}
        Câu hỏi: {input}
        Trả lời:
        """
        # Fix: the original referenced an undefined `prompt_template`
        # (the from_template call was commented out) -> NameError.
        prompt_template = ChatPromptTemplate.from_template(template1)

        combined_document_chain = create_stuff_documents_chain(llm, prompt_template)
        retriever = st.session_state.vector_store.as_retriever()
        retrieval_chain = create_retrieval_chain(retriever, combined_document_chain)

        # Replay previous turns so the conversation persists across reruns.
        for query, answer in st.session_state.chat_history:
            with st.chat_message("user"):
                st.write(query)
            with st.chat_message("assistant"):
                st.write(answer)

        user_query = st.text_input("Enter your question here:")
        if user_query:
            start = time.process_time()
            try:
                response = retrieval_chain.invoke({"input": user_query})
                response_time = time.process_time() - start
                st.write(f"Response processed in {response_time:.2f} seconds.")
                st.write(response['answer'])

                with st.expander("View Similar Document Snippets"):
                    for i, doc in enumerate(response["context"]):
                        st.write(f"Document {i + 1}:")
                        st.write(doc.page_content)
                        st.write("--------------------------------")

                feedback = st.radio("Was this answer helpful?", ('Yes', 'No'))
                if feedback:
                    st.session_state.feedback = feedback
                    if feedback == 'No':
                        st.text_area(
                            "Please provide more details on how we can improve:",
                            key='feedback_details',
                        )
            except Exception as e:
                # Boundary handler: surface the error in the UI rather than
                # crashing the Streamlit script run.
                st.error(f"Error during response retrieval: {e}")
    else:
        st.warning("LLM initialization failed or documents are not loaded. Please verify the API key and document directory.")


if __name__ == "__main__":
    main()