import os
import random
import itertools

import streamlit as st
import validators

from langchain.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader, WebBaseLoader
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.callbacks import StdOutCallbackHandler
from langchain.chains import ConversationalRetrievalChain, QAGenerationChain, LLMChain
from langchain.memory import ConversationBufferMemory
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
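
# Streamlit app: upload PDF/TXT/DOCX documents (or point at a web URL), index the
# text in a FAISS vector store, and answer questions over it with a LangChain
# ConversationalRetrievalChain that also surfaces the source passages it used.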

st.set_page_config(page_title="DOC QA", page_icon=':book:')

memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True, output_key='answer')

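# Uploaded files are saved under ./tempdir because the LangChain document loaders
# below (PyPDFLoader, TextLoader, Docx2txtLoader) read from filesystem paths rather
# than in-memory buffers.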
@st.cache_data
def save_file_locally(file):
    '''Save uploaded files locally'''
    os.makedirs('tempdir', exist_ok=True)  # create the target directory on first use
    doc_path = os.path.join('tempdir', file.name)
    with open(doc_path, 'wb') as f:
        f.write(file.getbuffer())

    return doc_path

@st.cache_data
def load_prompt():
    '''Build the chat prompt used to answer questions over the retrieved context.'''
    system_template = """Use only the following pieces of context to answer the user's question accurately.
Do not use any information not provided in the given context.
If you don't know the answer, just say 'There is no relevant answer in the given documents',
don't try to make up an answer.

ALWAYS return a "SOURCES" part in your answer.
The "SOURCES" part should be a reference to the source of the document from which you got your answer.

Remember, do not reference any information not given in the context.
If the answer is not available in the given context just say 'There is no relevant answer in the given documents'.

Follow the below format when answering:

Question: {question}
SOURCES: [xyz]

Begin!
----------------
{context}"""

    messages = [
        SystemMessagePromptTemplate.from_template(system_template),
        HumanMessagePromptTemplate.from_template("{question}")
    ]
    prompt = ChatPromptTemplate.from_messages(messages)

    return prompt
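
# Quick sanity check (hypothetical usage, not executed by the app):
#   load_prompt().format_prompt(question="What is X?", context="...").to_messages()
# returns the [system, human] message list that the chat model will receive.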

@st.cache_data
def load_docs(files, url=False):
    '''Load uploaded files (or a web page) and return their combined text.'''
    documents = []

    if not url:
        st.info("`Reading doc ...`")
        for file in files:
            file_extension = os.path.splitext(file.name)[1]
            doc_path = save_file_locally(file)

            if file_extension == ".pdf":
                pages = PyPDFLoader(doc_path)
                documents.extend(pages.load())
            elif file_extension == ".txt":
                pages = TextLoader(doc_path)
                documents.extend(pages.load())
            elif file_extension == ".docx":
                pages = Docx2txtLoader(doc_path)
                documents.extend(pages.load())
            else:
                st.warning('Please provide txt or pdf or docx.', icon="⚠️")
    else:
        st.info("`Reading web link ...`")
        loader = WebBaseLoader(files)
        documents = loader.load()

    return ','.join([doc.page_content for doc in documents])
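
# Note: each loader yields one Document per page/file; only the raw text is needed
# downstream, so load_docs flattens everything into a single comma-joined string.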

bi_enc_dict = {'mpnet-base-v2': "all-mpnet-base-v2",
               'instructor-large': 'hkunlp/instructor-large'}

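# 'all-mpnet-base-v2' embeds queries and passages identically, while
# 'hkunlp/instructor-large' is instruction-tuned and takes separate instructions
# for queries and passages, as wired up below.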
@st.cache_data
def gen_embeddings(model_name):
    '''Generate embeddings for given model'''
    if model_name == 'mpnet-base-v2':
        embeddings = HuggingFaceEmbeddings(model_name=bi_enc_dict[model_name])
    elif model_name == 'instructor-large':
        embeddings = HuggingFaceInstructEmbeddings(model_name=bi_enc_dict[model_name],
                                                   query_instruction='Represent the question for retrieving supporting paragraphs: ',
                                                   embed_instruction='Represent the paragraph for retrieval: ')

    return embeddings

def load_retrieval_chain(vectorstore):
    '''Load Chain'''
    callback_handler = [StdOutCallbackHandler()]

    chat_llm = ChatOpenAI(streaming=True,
                          model_name='gpt-4',
                          callbacks=callback_handler,
                          verbose=True,
                          temperature=0)

    # Condenses a follow-up question plus chat history into a standalone question.
    question_generator = LLMChain(llm=chat_llm, prompt=CONDENSE_QUESTION_PROMPT)
    # "stuff" chain: all retrieved chunks are placed into a single prompt.
    doc_chain = load_qa_chain(llm=chat_llm, chain_type="stuff", prompt=load_prompt())

    chain = ConversationalRetrievalChain(retriever=vectorstore.as_retriever(search_kwargs={"k": 3}),
                                         question_generator=question_generator,
                                         combine_docs_chain=doc_chain,
                                         memory=memory,
                                         return_source_documents=True,
                                         get_chat_history=lambda h: h)

    return chain
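
# Per conversational turn: the question generator condenses the user's question with
# the chat history, the retriever pulls the top-3 chunks, and the stuff chain
# answers from those chunks using the prompt from load_prompt().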

@st.cache_resource
def process_corpus(corpus, model_name, chunk_size=1000, overlap=50):
    '''Process text for Semantic Search'''
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
    texts = text_splitter.split_text(corpus)

    num_chunks = len(texts)
    st.write(f"Number of text chunks: {num_chunks}")

    embeddings = gen_embeddings(model_name)
    vectorstore = FAISS.from_texts(texts, embeddings)
    chain = load_retrieval_chain(vectorstore)

    return chain
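
# st.cache_resource keeps the FAISS index and chain objects alive across reruns;
# st.cache_data would instead try to serialize them, which these objects do not
# support cleanly.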

@st.cache_data
def run_qa_chain(text, query, model_name):
    '''Run the QnA chain'''
    chain = process_corpus(text, model_name)
    answer = chain({"question": query})

    return answer
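
# Hypothetical usage: run_qa_chain(raw_text, "What does section 2 cover?", 'mpnet-base-v2')
# returns a dict with 'answer' and 'source_documents' keys.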

@st.cache_resource
def gen_qa_response(text, model_name, user_question):
    '''Generate responses from query'''
    if user_question:
        result = run_qa_chain(text, user_question, model_name)

        references = [doc.page_content for doc in result['source_documents']]
        answer = result['answer']

        with st.expander(label='Query Result', expanded=True):
            st.write(answer)

        with st.expander(label='References from Corpus used to Generate Result'):
            for ref in references:
                st.write(ref)

    if 'eval_set' not in st.session_state:
        num_eval_questions = 10
        st.session_state.eval_set = generate_eval(text, num_eval_questions, 3000)

    for i, qa_pair in enumerate(st.session_state.eval_set):
        st.sidebar.markdown(
            f"""
            <div class="css-card">
            <span class="card-tag">Question {i + 1}</span>
            <p style="font-size: 12px;">{qa_pair['question']}</p>
            <p style="font-size: 12px;">{qa_pair['answer']}</p>
            </div>
            """,
            unsafe_allow_html=True,
        )

    st.write("Ready to answer questions.")

@st.cache_data
def generate_eval(raw_text, N, chunk):
    '''Generate N sample question-answer pairs from random chunks of the raw text.'''
    update = st.empty()
    ques_update = st.empty()
    update.info("`Generating sample questions ...`")

    n = len(raw_text)
    chunk = min(chunk, n)  # guard against texts shorter than the requested chunk size
    starting_indices = [random.randint(0, n - chunk) for _ in range(N)]
    sub_sequences = [raw_text[i:i + chunk] for i in starting_indices]

    chain = QAGenerationChain.from_llm(ChatOpenAI(temperature=0, model_name='gpt-4'))
    eval_set = []
    for i, b in enumerate(sub_sequences):
        try:
            qa = chain.run(b)
            eval_set.append(qa)
            ques_update.info(f"Creating Question: {i + 1}")
        except Exception:
            st.warning(f'Error in generating Question: {i + 1}...', icon="⚠️")
            continue

    # chain.run returns a list of QA dicts per chunk; flatten into one list.
    eval_set_full = list(itertools.chain.from_iterable(eval_set))

    update.empty()
    ques_update.empty()

    return eval_set_full

st.markdown(
    """
    <style>
    #MainMenu {visibility: hidden;
    }
    footer {visibility: hidden;
    }
    .css-card {
        border-radius: 0px;
        padding: 30px 10px 10px 10px;
        background-color: black;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
        margin-bottom: 10px;
        font-family: "IBM Plex Sans", sans-serif;
    }

    .card-tag {
        border-radius: 0px;
        padding: 1px 5px 1px 5px;
        margin-bottom: 10px;
        position: absolute;
        left: 0px;
        top: 0px;
        font-size: 0.6rem;
        font-family: "IBM Plex Sans", sans-serif;
        color: white;
        background-color: green;
    }

    .css-zt5igj {left: 0;
    }
    span.css-10trblm {margin-left: 0;
    }
    div.css-1kyxreq {margin-top: -40px;
    }
    </style>
    """,
    unsafe_allow_html=True,
)

st.sidebar.image("img/logo.jpg")
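
# The .css-card and .card-tag rules above style the sample-question cards that
# gen_qa_response injects into the sidebar.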

st.write(
    """
    <div style="display: flex; align-items: center; margin-left: 0;">
    <h1 style="display: inline-block;">DOC GPT</h1>
    <sup style="margin-left:5px;font-size:small; color: green;">beta</sup>
    </div>
    """,
    unsafe_allow_html=True,
)

st.sidebar.title("Menu")

splitter_type = "RecursiveCharacterTextSplitter"

uploaded_files = st.file_uploader("Upload a PDF or TXT or DOCX Document", type=[
    "pdf", "txt", "docx"], accept_multiple_files=True)

st.markdown(
    "<h3 style='text-align: center; color: red;'>OR</h3>",
    unsafe_allow_html=True,
)

url_text = st.text_input("Please enter the URL of an HTML page you would like to load..")

model_name = st.sidebar.selectbox("Embedding Model", options=list(bi_enc_dict.keys()), key='sbox')

if uploaded_files:
    # Re-uploading different files invalidates the cached sample questions.
    if 'last_uploaded_files' not in st.session_state or st.session_state.last_uploaded_files != uploaded_files:
        st.session_state.last_uploaded_files = uploaded_files
        if 'eval_set' in st.session_state:
            del st.session_state['eval_set']

    raw_text = load_docs(uploaded_files)
    st.success("Documents uploaded and processed.")

    user_question = st.text_input("Enter your question:")
    gen_qa_response(raw_text, model_name, user_question)

elif url_text and validators.url(url_text):
    # A new URL likewise invalidates the cached sample questions.
    if 'url_files' not in st.session_state or st.session_state.url_files != url_text:
        st.session_state.url_files = url_text
        if 'eval_set' in st.session_state:
            del st.session_state['eval_set']

    loaded_docs = load_docs(url_text, url=True)
    st.success("Web Document uploaded and processed.")

    user_question = st.text_input("Enter your question:", key='url_question')
    gen_qa_response(loaded_docs, model_name, user_question)

st.markdown("")