import os
import random
import itertools

import streamlit as st
import validators
from langchain.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader, WebBaseLoader
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain, QAGenerationChain, LLMChain
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.callbacks import StdOutCallbackHandler
from langchain.memory import ConversationBufferMemory
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

st.set_page_config(page_title="DOC QA", page_icon=':book:')

memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True, output_key='answer')


@st.cache_data
def save_file_locally(file):
    '''Save an uploaded file locally and return its path'''
    os.makedirs('tempdir', exist_ok=True)  # make sure the temp directory exists
    doc_path = os.path.join('tempdir', file.name)
    with open(doc_path, 'wb') as f:
        f.write(file.getbuffer())
    return doc_path


@st.cache_data
def load_prompt():
    '''Build the chat prompt that forces context-only answers with sources'''
    system_template = """Use only the following pieces of context to answer the user's question accurately. Do not use any information not provided in the given context.
If you don't know the answer, just say 'There is no relevant answer in the given documents'; don't try to make up an answer.
ALWAYS return a "SOURCES" part in your answer. The "SOURCES" part should be a reference to the source of the document from which you got your answer.
Remember, do not reference any information not given in the context. If the answer is not available in the given context, just say 'There is no relevant answer in the given documents'.
Follow the format below when answering:
Question: {question}
SOURCES: [xyz]
Begin!
----------------
{context}"""
    messages = [
        SystemMessagePromptTemplate.from_template(system_template),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
    return ChatPromptTemplate.from_messages(messages)


@st.cache_data
def load_docs(files, url=False):
    '''Load documents from uploaded files or a web link and return the concatenated text'''
    documents = []
    if not url:
        st.info("`Reading doc ...`")
        for file in files:
            file_extension = os.path.splitext(file.name)[1]
            doc_path = save_file_locally(file)
            if file_extension == ".pdf":
                documents.extend(PyPDFLoader(doc_path).load())
            elif file_extension == ".txt":
                documents.extend(TextLoader(doc_path).load())
            elif file_extension == ".docx":
                documents.extend(Docx2txtLoader(doc_path).load())
            else:
                st.warning('Please provide txt or pdf or docx.', icon="⚠️")
    else:
        st.info("`Reading web link ...`")
        documents = WebBaseLoader(files).load()
    return ','.join([doc.page_content for doc in documents])


# Map display names to HuggingFace model identifiers
bi_enc_dict = {'mpnet-base-v2': "all-mpnet-base-v2",
               'instructor-large': 'hkunlp/instructor-large'}


@st.cache_resource
def gen_embeddings(model_name):
    '''Instantiate the embedding model for the given model name'''
    if model_name == 'mpnet-base-v2':
        embeddings = HuggingFaceEmbeddings(model_name=bi_enc_dict[model_name])
    elif model_name == 'instructor-large':
        embeddings = HuggingFaceInstructEmbeddings(
            model_name=bi_enc_dict[model_name],
            query_instruction='Represent the question for retrieving supporting paragraphs: ',
            embed_instruction='Represent the paragraph for retrieval: ')
    return embeddings


def load_retrieval_chain(vectorstore):
    '''Build the conversational retrieval chain on top of the vectorstore'''
    # Initialize the chain with streaming output
    callback_handler = [StdOutCallbackHandler()]
    chat_llm = ChatOpenAI(streaming=True,
                          model_name='gpt-4',
                          callbacks=callback_handler,
                          verbose=True,
                          temperature=0)
    # Condense the follow-up question plus chat history into a standalone question
    question_generator = LLMChain(llm=chat_llm, prompt=CONDENSE_QUESTION_PROMPT)
    # Stuff the retrieved documents into the custom context-only prompt
    doc_chain = load_qa_chain(llm=chat_llm, chain_type="stuff", prompt=load_prompt())
    chain = ConversationalRetrievalChain(
        retriever=vectorstore.as_retriever(search_kwargs={"k": 3}),
        question_generator=question_generator,
        combine_docs_chain=doc_chain,
        memory=memory,
        return_source_documents=True,
        get_chat_history=lambda h: h)
    return chain


@st.cache_resource
def process_corpus(corpus, model_name, chunk_size=1000, overlap=50):
    '''Split, embed, and index the text for semantic search'''
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
    texts = text_splitter.split_text(corpus)

    # Display the number of text chunks
    st.write(f"Number of text chunks: {len(texts)}")

    embeddings = gen_embeddings(model_name)
    vectorstore = FAISS.from_texts(texts, embeddings)
    return load_retrieval_chain(vectorstore)


@st.cache_data
def run_qa_chain(text, query, model_name):
    '''Run the QnA chain for a single query'''
    chain = process_corpus(text, model_name)
    return chain({"question": query})
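
# gen_qa_response (below) calls generate_eval(), which is not defined anywhere in
# this file. A minimal sketch, assuming the common QAGenerationChain pattern of
# drawing random fixed-size excerpts from the corpus; the otherwise-unused `random`,
# `itertools`, and QAGenerationChain imports above suggest this is what the original
# did. Treat the body as an illustration, not the original implementation.
def generate_eval(text, num_questions, chunk):
    '''Generate up to num_questions QA pairs from random chunk-sized excerpts of text'''
    st.info("`Generating sample questions ...`")
    n = len(text)
    # Sample random starting offsets so the excerpts cover different parts of the corpus
    starting_indices = [random.randint(0, max(n - chunk, 0)) for _ in range(num_questions)]
    sub_sequences = [text[i:i + chunk] for i in starting_indices]
    qa_chain = QAGenerationChain.from_llm(ChatOpenAI(temperature=0))
    eval_set = []
    for i, excerpt in enumerate(sub_sequences):
        try:
            # run() returns a list of {'question': ..., 'answer': ...} dicts
            eval_set.append(qa_chain.run(excerpt))
        except Exception:
            st.warning(f'Error generating question {i + 1}.', icon="⚠️")
    # Flatten the per-excerpt lists into one list of QA pairs
    return list(itertools.chain.from_iterable(eval_set))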

@st.cache_resource
def gen_qa_response(text, model_name, user_question):
    '''Answer the user's question and show supporting references'''
    if user_question:
        result = run_qa_chain(text, user_question, model_name)
        references = [doc.page_content for doc in result['source_documents']]
        answer = result['answer']
        with st.expander(label='Query Result', expanded=True):
            st.write(answer)
        with st.expander(label='References from Corpus used to Generate Result'):
            for ref in references:
                st.write(ref)

    # Check if there are no generated question-answer pairs in the session state
    if 'eval_set' not in st.session_state:
        # Use the generate_eval function to generate question-answer pairs
        num_eval_questions = 10  # Number of question-answer pairs to generate
        st.session_state.eval_set = generate_eval(text, num_eval_questions, 3000)

    # Display the question-answer pairs in the sidebar with smaller text
    # NOTE: the original snippet breaks off inside this markdown call; the
    # small-text markup below is an assumed minimal completion.
    for i, qa_pair in enumerate(st.session_state.eval_set):
        st.sidebar.markdown(
            f"""
            <p style="font-size: 12px;"><b>{qa_pair['question']}</b></p>
            <p style="font-size: 12px;">{qa_pair['answer']}</p>
            """,
            unsafe_allow_html=True,
        )
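
# The snippet ends before the page layout that wires these functions together.
# A hypothetical minimal wiring, assuming a model selector, a file uploader, and
# a question box; every widget label and the main() name below are illustrative,
# not taken from the original app.
def main():
    st.title("DOC QA")
    model_name = st.selectbox("Embedding model", list(bi_enc_dict.keys()))
    uploaded_files = st.file_uploader("Upload documents",
                                      type=["pdf", "txt", "docx"],
                                      accept_multiple_files=True)
    if uploaded_files:
        text = load_docs(uploaded_files)
        user_question = st.text_input("Ask a question about your documents:")
        gen_qa_response(text, model_name, user_question)


if __name__ == "__main__":
    main()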