EddyGiusepe commited on
Commit
0b866d1
·
1 Parent(s): e7fbeab

Corrigir ainda ...

Browse files
Files changed (1) hide show
  1. 4.1_RAG_chroma_pdf_qa.py +105 -0
4.1_RAG_chroma_pdf_qa.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data Scientist.: Dr. Eddy Giusepe Chirinos Isidro
3
+
4
+ Link de estudo --> https://sophiamyang.medium.com/building-a-retrieval-augmented-generation-chatbot-d567a24fcd14
5
+ """
6
+ import os
7
+ import tempfile
8
+
9
+ import panel as pn
10
+ from langchain.chains import RetrievalQA
11
+ from langchain.document_loaders import PyPDFLoader
12
+ from langchain.embeddings import OpenAIEmbeddings
13
+ from langchain.llms import OpenAI
14
+ from langchain.text_splitter import CharacterTextSplitter
15
+ from langchain.vectorstores import Chroma
16
+
17
+
18
+ import panel as pn
19
+ from panel.chat import ChatInterface
20
+ pn.extension("perspective")
21
+
22
+
23
+ def initialize_chain():
24
+ if key_input.value:
25
+ os.environ["OPENAI_API_KEY"] = key_input.value
26
+
27
+ selections = (pdf_input.value, k_slider.value, chain_select.value)
28
+ if selections in pn.state.cache:
29
+ return pn.state.cache[selections]
30
+
31
+ chat_input.placeholder = "Ask questions here!"
32
+
33
+ # load document
34
+ with tempfile.NamedTemporaryFile("wb", delete=False) as f:
35
+ f.write(pdf_input.value)
36
+ file_name = f.name
37
+ loader = PyPDFLoader(file_name)
38
+ documents = loader.load()
39
+ # split the documents into chunks
40
+ text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
41
+ texts = text_splitter.split_documents(documents)
42
+ # select which embeddings we want to use
43
+ embeddings = OpenAIEmbeddings()
44
+ # create the vectorestore to use as the index
45
+ db = Chroma.from_documents(texts, embeddings)
46
+ # expose this index in a retriever interface
47
+ retriever = db.as_retriever(
48
+ search_type="similarity", search_kwargs={"k": k_slider.value}
49
+ )
50
+ # create a chain to answer questions
51
+ qa = RetrievalQA.from_chain_type(
52
+ llm=OpenAI(),
53
+ chain_type=chain_select.value,
54
+ retriever=retriever,
55
+ return_source_documents=True,
56
+ verbose=True,
57
+ )
58
+ return qa
59
+
60
+
61
+ async def respond(contents, user, chat_interface):
62
+ if not pdf_input.value:
63
+ chat_interface.send(
64
+ {"user": "System", "value": "Please first upload a PDF!"}, respond=False
65
+ )
66
+ return
67
+ elif chat_interface.active == 0:
68
+ chat_interface.active = 1
69
+ chat_interface.active_widget.placeholder = "Ask questions here!"
70
+ yield {"user": "OpenAI", "value": "Let's chat about the PDF!"}
71
+ return
72
+
73
+ qa = initialize_chain()
74
+ response = qa({"query": contents})
75
+ answers = pn.Column(response["result"])
76
+ answers.append(pn.layout.Divider())
77
+ for doc in response["source_documents"][::-1]:
78
+ answers.append(f"**Page {doc.metadata['page']}**:")
79
+ answers.append(f"```\n{doc.page_content}\n```")
80
+ yield {"user": "OpenAI", "value": answers}
81
+
82
+
83
+ pdf_input = pn.widgets.FileInput(accept=".pdf", value="", height=50)
84
+ key_input = pn.widgets.PasswordInput(
85
+ name="OpenAI Key",
86
+ placeholder="sk-...",
87
+ )
88
+ k_slider = pn.widgets.IntSlider(
89
+ name="Number of Relevant Chunks", start=1, end=5, step=1, value=2
90
+ )
91
+ chain_select = pn.widgets.RadioButtonGroup(
92
+ name="Chain Type", options=["stuff", "map_reduce", "refine", "map_rerank"]
93
+ )
94
+ chat_input = pn.widgets.TextInput(placeholder="First, upload a PDF!")
95
+ chat_interface = pn.chat.ChatInterface(
96
+ callback=respond, sizing_mode="stretch_width", widgets=[pdf_input, chat_input]
97
+ )
98
+ chat_interface.send(
99
+ {"user": "System", "value": "Please first upload a PDF and click send!"},
100
+ respond=False,
101
+ )
102
+ template = pn.template.BootstrapTemplate(
103
+ sidebar=[key_input, k_slider, chain_select], main=[chat_interface]
104
+ )
105
+ template.servable()