| import os | |
| import tempfile | |
| import streamlit as st | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_core.messages import AIMessage, HumanMessage | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_openai.chat_models import ChatOpenAI | |
| from langchain_openai.embeddings import OpenAIEmbeddings | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from utils import FileParser | |
# Name of the on-disk FAISS index directory.
vector_database_name = "rag-poc"
# NOTE(review): appears unused at module level — load_and_split creates its
# own temporary directory; confirm before relying on this constant.
temp_pdf_folder = "temp-files"
# Index location; root directory is configurable via VECTOR_DATABASE_PATH.
vector_database_path = (
    f"{os.environ.get('VECTOR_DATABASE_PATH', '.')}/{vector_database_name}"
)
# Module-level retriever handle; populated at startup and refreshed after
# each successful upload (see append_to_vector_db).
RETRIEVER = None
def load_and_split(file, ocr_enabled):
    """Write an uploaded file to a scratch directory, parse it, and chunk it.

    Args:
        file: Uploaded file object exposing ``.name`` and ``.getvalue()``
            (Streamlit ``UploadedFile``).
        ocr_enabled: Forwarded to ``FileParser.parse``; presumably enables OCR
            for scanned documents — behavior defined in utils.FileParser.

    Returns:
        list: Document chunks tagged with ``{"file_name": file.name}``;
        empty list when the parser extracted no text.
    """
    # Use a distinct local name: the original code bound the temp dir to
    # `temp_pdf_folder`, shadowing the module-level constant of the same name.
    with tempfile.TemporaryDirectory() as scratch_dir:
        local_filepath = os.path.join(scratch_dir, file.name)
        with open(local_filepath, "wb") as f:
            f.write(file.getvalue())
        # Parse while the temp dir still exists.
        text = FileParser().parse(input_dir=scratch_dir, ocr_enabled=ocr_enabled)
    docs = []
    if text:
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=512, chunk_overlap=100
        )
        texts = text_splitter.split_text(text)
        # One metadata dict per chunk so every chunk carries its source file.
        docs = text_splitter.create_documents(
            texts=texts, metadatas=[{"file_name": file.name}] * len(texts)
        )
    return docs
def initialize_vector_db():
    """Create a fresh FAISS index, persist it to disk, and return it.

    The index is seeded with a single empty text because FAISS cannot be
    created from zero documents.
    """
    db = FAISS.from_texts([""], OpenAIEmbeddings())
    db.save_local(vector_database_path)
    return db
def load_vector_db():
    """Load the persisted FAISS index, creating a new one when none exists."""
    if not os.path.exists(vector_database_path):
        return initialize_vector_db()
    # FAISS persists via pickle, hence the explicit opt-in flag; acceptable
    # here only because this process wrote the index itself.
    return FAISS.load_local(
        vector_database_path,
        OpenAIEmbeddings(),
        allow_dangerous_deserialization=True,
    )
def append_to_vector_db(docs=None):
    """Embed *docs*, merge them into the persisted index, refresh RETRIEVER.

    Args:
        docs: List of LangChain Documents to add. ``None`` or empty is a
            no-op (the original mutable default ``docs: list = []`` would
            have crashed inside ``FAISS.from_documents``).

    Side effects:
        Rewrites the on-disk index and rebinds the module-level RETRIEVER.
    """
    global RETRIEVER
    if not docs:
        # Skip the embedding-API call and merge for an empty batch.
        return
    existing_vector_db = load_vector_db()
    new_vector_db = FAISS.from_documents(docs, OpenAIEmbeddings())
    existing_vector_db.merge_from(new_vector_db)
    existing_vector_db.save_local(vector_database_path)
    RETRIEVER = existing_vector_db.as_retriever()
def create_embeddings(files=None, ocr_enabled: bool = False):
    """Vectorise each uploaded file and record successes in session state.

    Args:
        files: Uploaded file objects to process. The original signature used
            a mutable default (``files: list = []``) — replaced with ``None``.
        ocr_enabled: Forwarded to load_and_split / FileParser.

    Side effects:
        Updates the FAISS index, appends processed names to
        ``st.session_state.last_uploaded_files``, and shows toasts.
    """
    for file in files or []:
        docs = load_and_split(file=file, ocr_enabled=ocr_enabled)
        if docs:
            append_to_vector_db(docs=docs)
            # Remember the name so the file is not re-vectorised on rerun.
            st.session_state.last_uploaded_files.append(file.name)
            st.toast(f"{file.name} processed successfully")
            print(f"{file.name} processed successfully")
        else:
            st.toast(f"{file.name} could not be processed")
            print(f"{file.name} could not be processed")
def get_response(user_query, chat_history):
    """Stream an LLM answer grounded in retrieved document chunks.

    Args:
        user_query: The user's latest question.
        chat_history: Sequence of AIMessage/HumanMessage objects.

    Returns:
        A generator of response text chunks (``chain.stream``).
    """
    # Primary retrieval on the current question.
    docs = RETRIEVER.invoke(user_query, k=20)
    # Second pass over all prior human turns to pull in conversation-level
    # context the latest question alone might miss.
    additional_info = RETRIEVER.invoke(
        " ".join(
            message.content
            for message in chat_history
            if isinstance(message, HumanMessage)
        ),
        k=20,
    )
    # De-duplicate by page content. BUGFIX: the original snapshotted the
    # contents once, so duplicate chunks *within* additional_info were each
    # appended; grow the seen-set as we append.
    seen = {doc.page_content for doc in docs}
    for doc in additional_info:
        if doc.page_content not in seen:
            docs.append(doc)
            seen.add(doc.page_content)
    template = """
You are Sifa, a virtual assistant designed by Sifars.
Execute the below mandatory considerations when responding to the inquiries:
--- Tone - Respectful, Patient, and Encouraging:
Maintain a tone that is not only polite but also encouraging. Positive language can help build confidence, especially when they are trying to learn something new.
Be mindful of cultural references or idioms that may not be universally understood or may date back to a different era, ensuring relatability.
--- Clarity - Simple, Direct, and Unambiguous:
Avoid abbreviations, slang, or colloquialisms that might be confusing. Stick to standard language.
Use bullet points or numbered lists to break down instructions or information, which can aid in comprehension.
--- Structure - Organized, Consistent, and Considerate:
Include relevant examples or analogies that relate to experiences common in their lifetime, which can aid in understanding complex topics.
--- Empathy and Understanding - Compassionate and Responsive:
Recognize and validate their feelings or concerns. Phrases like, “It’s completely normal to find this challenging,” can be comforting.
Be aware of the potential need for more frequent repetition or rephrasing of information for clarity.
Answer the following questions considering the documents and/or history of the conversation.
Chat history: {chat_history}
Documents from files: {retrieved_info}
User question: {user_question}
"""
    prompt = ChatPromptTemplate.from_template(template)
    llm = ChatOpenAI(model="gpt-4o", streaming=True)
    chain = prompt | llm | StrOutputParser()
    return chain.stream(
        {
            "chat_history": chat_history,
            "retrieved_info": docs,
            "user_question": user_query,
        }
    )
def main():
    """Render the Streamlit chat UI: history, chat input, and file sidebar."""
    st.set_page_config(page_title="RAG POC", page_icon="")
    st.title("RAG POC")

    # One-time session-state initialisation (survives Streamlit reruns).
    if "last_uploaded_files" not in st.session_state:
        st.session_state.last_uploaded_files = []
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = [
            AIMessage(content="Hello, I am Sifa. How can I help you?"),
        ]

    # Replay the conversation so far.
    for message in st.session_state.chat_history:
        if isinstance(message, AIMessage):
            role = "AI"
        elif isinstance(message, HumanMessage):
            role = "Human"
        else:
            continue
        with st.chat_message(role):
            st.write(message.content)

    # Handle a new user turn: echo it, stream the answer, persist both.
    user_query = st.chat_input("Type your message here...")
    if user_query:
        st.session_state.chat_history.append(HumanMessage(content=user_query))
        with st.chat_message("Human"):
            st.markdown(user_query)
        with st.chat_message("AI"):
            response = st.write_stream(
                get_response(
                    user_query=user_query, chat_history=st.session_state.chat_history
                )
            )
        st.session_state.chat_history.append(AIMessage(content=response))

    # Sidebar: file uploads; only vectorise files not seen before.
    uploaded_files = st.sidebar.file_uploader(
        label="Upload files", accept_multiple_files=True
    )
    ocr_enabled = st.sidebar.checkbox("Enable OCR", value=False)
    pending = [
        f for f in uploaded_files if f.name not in st.session_state.last_uploaded_files
    ]
    if pending:
        create_embeddings(files=pending, ocr_enabled=ocr_enabled)
if __name__ == "__main__":
    # Streamlit executes the script as __main__ on every rerun, so the
    # retriever is (re)built from the persisted index before the UI renders.
    RETRIEVER = load_vector_db().as_retriever()
    main()