PDF_Doc_Search / main.py
Lucas
Changing file PDF location.
cece2a6
import os
from langchain.llms import OpenAI
from langchain.chains import RetrievalQA
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.document_loaders import PyPDFLoader
from langchain import PromptTemplate
from langchain.chains.summarize import load_summarize_chain
import textwrap
import panel as pn
import PyPDF2
# Load Panel once with everything the app needs: the TextEditor widget
# resources, toast notifications (used by the callbacks via
# pn.state.notifications) and the Bootstrap page template.
pn.extension(
    'texteditor',
    notifications=True,
    template="bootstrap",
    sizing_mode='stretch_width',
)
pn.state.template.param.update(
    main_max_width="690px",
    header_background="#F08080",
)

# --- Inputs shared by both the summary and the Q&A flows ---
file_input = pn.widgets.FileInput(width=300)
openaikey = pn.widgets.PasswordInput(
    value="", placeholder="Entre com a OpenAI API Key aqui...", width=300
)

# --- Q&A controls ---
prompt = pn.widgets.TextEditor(
    value="", placeholder="Entre com sua pergunta aqui...", height=160, toolbar=False
)
run_button = pn.widgets.Button(name="Run!")
summary_button = pn.widgets.Button(name="Resumo!")
# Number of chunks the retriever returns (passed as `k` to qa()).
select_k = pn.widgets.IntSlider(
    name="Number of relevant chunks", start=1, end=5, step=1, value=2
)
# LangChain RetrievalQA chain type (passed as `chain_type` to qa()).
select_chain_type = pn.widgets.RadioButtonGroup(
    name='Chain type',
    options=['refine', 'map_reduce', "stuff", "map_rerank"]
)
# Q&A input panel: prompt + run button, with retrieval settings in a card.
widgets = pn.Row(
    pn.Column(prompt, run_button, margin=5),
    pn.Card(
        "Chain type:",
        pn.Column(select_chain_type, select_k),
        title="Advanced settings", margin=10
    ), width=600
)
# Row holding the summary button.
# NOTE(review): name looks like a typo for "summary_field"; kept because the
# page layout refers to it by this name.
summary_filed = pn.Row(
    pn.Column(summary_button),
    width=630
)
def is_valid_pdf(file_path):
    """Return True if *file_path* can be opened and parsed by PyPDF2.

    The file is opened in binary mode and handed to ``PyPDF2.PdfReader``;
    any failure (missing or unreadable file, malformed PDF, bad argument)
    yields False instead of propagating.

    Args:
        file_path: Path to the candidate PDF file.

    Returns:
        True if ``PdfReader`` accepted the file, otherwise False.
    """
    try:
        with open(file_path, 'rb') as f:
            PyPDF2.PdfReader(f)
    except Exception:
        # Deliberately broad: PyPDF2 raises many different error types on
        # malformed input. Unlike a bare ``except:``, this still lets
        # KeyboardInterrupt / SystemExit propagate.
        return False
    return True
def qa(file, query, chain_type, k):
    """Answer *query* about the PDF at *file* with retrieval-augmented QA.

    The PDF is loaded, split into chunks (``chunk_size=1000``, no overlap),
    embedded with OpenAI embeddings and indexed in a Chroma vector store.
    The *k* most similar chunks are then fed to an OpenAI LLM through a
    LangChain ``RetrievalQA`` chain of type *chain_type*.

    Args:
        file: Path to the PDF on disk.
        query: The question to answer.
        chain_type: RetrievalQA chain type, e.g. ``'refine'``,
            ``'map_reduce'``, ``'stuff'`` or ``'map_rerank'``.
        k: Number of chunks the retriever returns.

    Returns:
        The chain output dict, which contains ``'result'`` (the answer)
        and ``'source_documents'`` (the retrieved chunks). If *file* is not
        a readable PDF, ``{'error': 'Invalid PDF file.'}`` instead.
    """
    if not is_valid_pdf(file):
        return {'error': 'Invalid PDF file.'}

    # Load the document and split it into chunks.
    documents = PyPDFLoader(file).load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_documents(documents)

    # Index the chunks and expose them through a similarity retriever.
    embeddings = OpenAIEmbeddings()
    db = Chroma.from_documents(texts, embeddings)
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": k})

    # return_source_documents=True: the Q&A callback renders
    # result["source_documents"] under "Fontes:", which raised KeyError
    # while this flag was False.
    qa_chain = RetrievalQA.from_chain_type(
        llm=OpenAI(model_name="gpt-3.5-turbo", temperature=0),
        chain_type=chain_type,
        retriever=retriever,
        return_source_documents=True,
    )
    result = qa_chain({"query": query})
    print(result['result'])
    return result
def summary(file):
    """Summarize the PDF at *file* in Portuguese using a map-reduce chain.

    Each loaded document is summarized in 40 words or less (map step).
    The UI labels these as per-page summaries; presumably PyPDFLoader
    yields one document per page, so TODO confirm. The partial summaries
    are then combined into a ~100-word overall summary (reduce step).

    Args:
        file: Path to the PDF on disk.

    Returns:
        ``{'summary': str, 'steps': str}``. ``summary`` is the overall
        summary. ``steps`` holds the individual map-step summaries, each
        wrapped separately and separated by blank lines. If *file* is not
        a readable PDF, ``{'error': 'Invalid PDF file.'}`` instead.
    """
    if not is_valid_pdf(file):
        return {'error': 'Invalid PDF file.'}

    documents = PyPDFLoader(file).load()

    combine_prompt = PromptTemplate(
        template=(
            "Write a summary of the following in Portuguese in 100 words:\n"
            "{text}\n"
            "SUMMARY IN PORTUGUESE IN 100 WORDS:"
        ),
        input_variables=["text"],
    )
    map_prompt = PromptTemplate(
        template=(
            "Write a concise summary of the following in Portuguese in 40 words or less:\n"
            "{text}\n"
            "CONCISE SUMMARY IN PORTUGUESE IN 40 WORDS OR LESS:"
        ),
        input_variables=["text"],
    )
    chain = load_summarize_chain(
        OpenAI(temperature=0),
        chain_type="map_reduce",
        return_intermediate_steps=True,
        combine_prompt=combine_prompt,
        map_prompt=map_prompt,
    )
    output = chain({"input_documents": documents}, return_only_outputs=True)

    return {
        'summary': _wrap_summary_text(output['output_text']),
        # Wrap each step on its own. Filling text that already contains
        # newlines with replace_whitespace=False leaves newlines mid-line
        # (a documented textwrap pitfall).
        'steps': '\n\n'.join(
            _wrap_summary_text(step) for step in output['intermediate_steps']
        ),
    }


def _wrap_summary_text(text):
    """Fill *text* to 100 columns without breaking words across lines."""
    return textwrap.fill(
        text,
        width=100,
        break_long_words=False,
        replace_whitespace=False,
    )
# Conversation history rendered in the output area. Both qa_result and
# summary_result append Panel rows to it and re-render the whole list.
convos: list = []  # Panel objects (rows) in display order
def qa_result(_):
    """Run the Q&A chain on the current prompt and return the chat history.

    Steps:
    1. Stores the OpenAI key in the environment.
    2. Validates that a key, an uploaded file and a prompt were provided,
       notifying the user about the first missing one.
    3. Saves the upload to ``/.cache/temp.pdf`` and calls ``qa``.
    4. On success, appends the question and the answer (with its source
       chunks) to ``convos``.

    Args:
        _: Value passed in by ``pn.bind`` (ignored).

    Returns:
        A ``pn.Column`` containing every entry in ``convos``.
    """
    # Blank lines around the rule make it a thematic break in Markdown. A
    # dash line directly under text is a setext-heading underline instead,
    # which turned the last line of each chunk into a heading.
    separator = '\n\n--------------------------------------------------------------------\n\n'

    def render_history():
        return pn.Column(*convos, margin=15, width=575, min_height=400)

    os.environ["OPENAI_API_KEY"] = openaikey.value
    if not openaikey.value:
        pn.state.notifications.error('Missing API key.', duration=2000)
        return render_history()
    if file_input.value is None:
        pn.state.notifications.error('Missing file.', duration=2000)
        return render_history()

    # Persist the upload so the PDF loaders can read it from disk.
    file_input.save("/.cache/temp.pdf")

    prompt_text = prompt.value
    if not prompt_text:
        pn.state.notifications.error('Missing prompt.', duration=2000)
        return render_history()

    result = qa(file="/.cache/temp.pdf", query=prompt_text,
                chain_type=select_chain_type.value, k=select_k.value)
    if result.get('error') is not None:
        pn.state.notifications.error(result['error'], duration=2000)
        return render_history()

    # .get(): the chain only includes 'source_documents' when built with
    # return_source_documents=True; indexing it directly raised KeyError.
    sources = result.get("source_documents", [])
    convos.extend([
        pn.Row(
            pn.panel("\U0001F60A", width=10),
            prompt_text,
            width=600
        ),
        pn.Row(
            pn.panel("\U0001F916", width=10),
            pn.Column(
                result["result"],
                "Fontes:",
                pn.pane.Markdown(separator.join(doc.page_content for doc in sources))
            )
        )
    ])
    return render_history()
def summary_result(_):
    """Summarize the uploaded PDF and return the updated chat history.

    Steps:
    1. Puts the OpenAI key into the environment.
    2. Checks that a key and a file were supplied, notifying the user
       otherwise.
    3. Saves the upload to ``/.cache/temp.pdf`` and runs ``summary`` on it.
    4. On success, appends the overall summary and the per-page summaries
       to ``convos``.

    Args:
        _: Value passed in by ``pn.bind`` (ignored).

    Returns:
        A ``pn.Column`` containing every entry in ``convos``.
    """
    def history():
        return pn.Column(*convos, margin=15, width=575, min_height=400)

    api_key = openaikey.value
    os.environ["OPENAI_API_KEY"] = api_key
    if not api_key:
        pn.state.notifications.error('Missing API key.', duration=2000)
        return history()

    if file_input.value is None:
        pn.state.notifications.error('Missing file.', duration=2000)
        return history()

    # Persist the upload so the PDF loaders can read it from disk.
    pdf_path = "/.cache/temp.pdf"
    file_input.save(pdf_path)
    outcome = summary(file=pdf_path)

    error = outcome.get('error')
    if error is not None:
        pn.state.notifications.error(error, duration=2000)
        return history()

    user_row = pn.Row(
        pn.panel("\U0001F60A", width=10),
        "Resumo geral: ",
        outcome['summary'],
        width=600,
    )
    bot_row = pn.Row(
        pn.panel("\U0001F916", width=10),
        pn.Column("Resumo por página:", outcome['steps']),
    )
    convos.extend([user_row, bot_row])
    return history()
# Reactive output area: summary_result is bound to the "Resumo!" button and
# re-rendered when it is clicked, with a loading indicator while it runs.
# NOTE: the Q&A flow (qa_result bound to run_button, plus the `widgets` row
# in the layout) is currently disabled; only the summary flow is wired up.
qa_interactive = pn.panel(
    pn.bind(summary_result, summary_button),
    loading_indicator=True,
)
output = pn.WidgetBox('*As respostas aparecerão aqui:*', qa_interactive, width=630, scroll=True)

# Page layout: instructions, inputs, summary button and the output area.
pn.Column(
    pn.pane.Markdown("""
## \U0001F4D3 Resumo de um PDF
(original implementation: @sophiamyang)
1) Suba o PDF. 2) Entre com a OpenAI API key. 3) Clique "Resumo!".
"""),
    pn.Row(file_input, openaikey),
    summary_filed,
    output,
).servable()