Update app.py
Browse files
app.py
CHANGED
|
@@ -1,46 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
-
import tempfile
|
| 3 |
import os
|
| 4 |
-
from
|
| 5 |
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
|
| 9 |
|
| 10 |
-
|
| 11 |
-
if 'vectorstore' not in st.session_state:
|
| 12 |
-
st.session_state.vectorstore = None
|
| 13 |
-
if 'rag_chain' not in st.session_state:
|
| 14 |
-
st.session_state.rag_chain = None
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
|
| 25 |
-
#
|
| 26 |
-
|
| 27 |
-
splits = load_and_process_pdf(tmp_file_path)
|
| 28 |
-
st.session_state.vectorstore = create_vectorstore(splits)
|
| 29 |
-
st.session_state.rag_chain = create_rag_chain()
|
| 30 |
|
| 31 |
-
st.success("PDF processed successfully! Now you can ask questions.")
|
| 32 |
|
| 33 |
-
|
| 34 |
-
os.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
|
|
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import streamlit as st
# import pandas as pd
# from llm import load_and_process_pdf, create_vectorstore, create_rag_chain

import os
import subprocess
import time

import pandas as pd
import streamlit as st
import torch
from huggingface_hub import hf_hub_download
from langchain.chains import ConversationalRetrievalChain, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate
from langchain.schema import Document
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import PyPDFLoader, UnstructuredPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# chromadb.api.client.SharedSystemClient.clear_system_cache()
|
| 27 |
|
| 28 |
+
# Get TOKEN from environment variable
|
| 29 |
|
| 30 |
+
def process_pdf(file_path = r"chunk_metadata_template.xlsx"):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
+
df = pd.read_excel(file_path)
|
| 33 |
+
chunks = []
|
| 34 |
+
for i, row in df.iterrows():
|
| 35 |
+
# Create a Document object for each row, including page_content and metadata
|
| 36 |
+
chunk_with_metadata = Document(
|
| 37 |
+
page_content=row['page_content'], # Content for the chunk
|
| 38 |
+
metadata={
|
| 39 |
+
'chunk_id': row['chunk_id'], # Add chunk_id to the metadata
|
| 40 |
+
'document_title': row['document_title'], # Add document_title to the metadata
|
| 41 |
+
'topic': row['topic'],
|
| 42 |
+
'access': row['access'],# Add keywords to the metadata
|
| 43 |
+
}
|
| 44 |
+
)
|
| 45 |
+
# Append the Document object to the chunks list
|
| 46 |
+
chunks.append(chunk_with_metadata)
|
| 47 |
+
embeddings="BAAI/bge-base-en"
|
| 48 |
+
encode_kwargs = {'normalize_embeddings': True} # I.e. Cosine Similarity
|
| 49 |
|
| 50 |
+
embeddings = HuggingFaceEmbeddings(
|
| 51 |
+
model_name=embeddings,
|
| 52 |
+
model_kwargs={'device' : 'cpu' },
|
| 53 |
+
encode_kwargs=encode_kwargs
|
| 54 |
+
)
|
| 55 |
|
| 56 |
+
# return FAISS.from_documents(chunks, embedding=embeddings)
|
| 57 |
+
return Chroma.from_documents(chunks, embeddings)
|
|
|
|
|
|
|
|
|
|
| 58 |
|
|
|
|
| 59 |
|
| 60 |
+
def main():
|
| 61 |
+
TOKEN = os.environ.get('gemma2')
|
| 62 |
+
subprocess.run(["huggingface-cli", "login", "--token", TOKEN, "--add-to-git-credential"])
|
| 63 |
+
|
| 64 |
+
st.set_page_config(page_title="MBAL Chatbot", page_icon="🤖", layout="wide")
|
| 65 |
+
# Initialize session state
|
| 66 |
+
if "chat_history" not in st.session_state:
|
| 67 |
+
st.session_state.chat_history = []
|
| 68 |
+
if "vector_store" not in st.session_state:
|
| 69 |
+
st.session_state.vector_store = None
|
| 70 |
+
|
| 71 |
+
st.title("🤖 MBAL Insurance Assistant")
|
| 72 |
|
| 73 |
+
st.session_state.vector_store = process_pdf()
|
| 74 |
+
# Chat interface
|
| 75 |
+
if st.session_state.vector_store:
|
| 76 |
+
# Initialize conversation chain
|
| 77 |
|
| 78 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 79 |
+
"google/gemma-2b",
|
| 80 |
+
low_cpu_mem_usage=True,
|
| 81 |
+
torch_dtype=torch.float32
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
model_pipeline = pipeline(
|
| 88 |
+
"text-generation",
|
| 89 |
+
model=model,
|
| 90 |
+
tokenizer=tokenizer,
|
| 91 |
+
max_new_tokens=256,
|
| 92 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 93 |
+
device_map="auto"
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
llm = HuggingFacePipeline(
|
| 97 |
+
pipeline=model_pipeline
|
| 98 |
+
)
|
| 99 |
+
template1 = """
|
| 100 |
+
Bạn là một AI trợ lý chuyên cung cấp thông tin cho khách hàng về sản phẩm bảo hiểm của công ty MB Ageas Life tại Việt Nam.
|
| 101 |
+
Hãy trả lời chuyên nghiệp, chính xác, cung cấp thông tin bao quát trước, các trường hợp có thể xảy ra làm ví dụ rồi mới đặt câu hỏi gợi mở nếu chưa rõ. Tất cả các thông tin cung cấp đều trong phạm vi MBAL. Những có đủ thông tin khách hàng thì mời khách hàng đăng ký để nhận tư vấn trên https://www.mbageas.life/
|
| 102 |
+
{context}
|
| 103 |
+
Câu hỏi: {question}
|
| 104 |
+
Trả lời:
|
| 105 |
+
"""
|
| 106 |
+
combined_document_chain = create_stuff_documents_chain(llm, prompt_template)
|
| 107 |
+
retriever = st.session_state.vector_store.as_retriever()
|
| 108 |
+
retrieval_chain = create_retrieval_chain(retriever, combined_document_chain)
|
| 109 |
+
# RAG_prompt = ChatPromptTemplate.from_template(template=template1)
|
| 110 |
+
|
| 111 |
+
# qa = ConversationalRetrievalChain.from_llm(
|
| 112 |
+
# llm = llm,
|
| 113 |
+
# retriever =st.session_state.vector_store.as_retriever(),
|
| 114 |
+
# combine_docs_chain_kwargs={"prompt": RAG_prompt},
|
| 115 |
+
# memory=memory,
|
| 116 |
+
# condense_question_llm = None
|
| 117 |
+
|
| 118 |
+
# )
|
| 119 |
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# Display chat history
|
| 123 |
+
for query, answer in st.session_state.chat_history:
|
| 124 |
+
with st.chat_message("user"):
|
| 125 |
+
st.write(query)
|
| 126 |
+
with st.chat_message("assistant"):
|
| 127 |
+
st.write(answer)
|
| 128 |
+
|
| 129 |
+
# # Handle new query
|
| 130 |
+
# query = st.chat_input("Ask a question about the PDF:")
|
| 131 |
+
# if query:
|
| 132 |
+
# # Add user question to history
|
| 133 |
+
# st.session_state.chat_history.append((query, ""))
|
| 134 |
+
|
| 135 |
+
# try:
|
| 136 |
+
# # Get answer
|
| 137 |
+
# result = qa({"question": query})
|
| 138 |
+
# answer = result["answer"]
|
| 139 |
+
|
| 140 |
+
# # Update chat history
|
| 141 |
+
# st.session_state.chat_history[-1] = (query, answer)
|
| 142 |
+
|
| 143 |
+
# # Rerun to update display
|
| 144 |
+
# st.rerun()
|
| 145 |
+
|
| 146 |
+
# except Exception as e:
|
| 147 |
+
# st.error(f"Error processing query: {str(e)}")
|
| 148 |
+
|
| 149 |
+
user_query = st.text_input("Enter your question here:")
|
| 150 |
+
if user_query:
|
| 151 |
+
start = time.process_time()
|
| 152 |
+
try:
|
| 153 |
+
response = retrieval_chain.invoke({"input": user_query})
|
| 154 |
+
response_time = time.process_time() - start
|
| 155 |
+
st.write(f"Response processed in {response_time:.2f} seconds.")
|
| 156 |
+
st.write(response['answer'])
|
| 157 |
+
with st.expander("View Similar Document Snippets"):
|
| 158 |
+
for i, doc in enumerate(response["context"]):
|
| 159 |
+
st.write(f"Document {i + 1}:")
|
| 160 |
+
st.write(doc.page_content)
|
| 161 |
+
st.write("--------------------------------")
|
| 162 |
+
feedback = st.radio("Was this answer helpful?", ('Yes', 'No'))
|
| 163 |
+
if feedback:
|
| 164 |
+
st.session_state.feedback = feedback
|
| 165 |
+
if feedback == 'No':
|
| 166 |
+
st.text_area("Please provide more details on how we can improve:", key='feedback_details')
|
| 167 |
+
except Exception as e:
|
| 168 |
+
st.error(f"Error during response retrieval: {e}")
|
| 169 |
+
else:
|
| 170 |
+
st.warning("LLM initialization failed or documents are not loaded. Please verify the API key and document directory.")
|
| 171 |
+
main()
|