Entreprenerdly committed on
Commit
5ece9c6
·
verified ·
1 Parent(s): 257441b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +139 -0
app.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+ import PyPDF2
4
+ from io import BytesIO
5
+
6
+ from langchain_openai import OpenAIEmbeddings
7
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
8
+ from langchain_community.vectorstores import Chroma
9
+ from langchain.chains import ConversationalRetrievalChain
10
+ from langchain_openai import ChatOpenAI
11
+
12
+ from langchain.docstore.document import Document
13
+ from langchain.memory import ChatMessageHistory, ConversationBufferMemory
14
+
15
+ import chainlit as cl
16
+
17
# The OpenAI API key is taken from the `openaikey` environment variable; the
# value is None when that variable is not set.
open_aikey = os.environ.get('openaikey')
OPENAI_API_KEY = open_aikey

# Shared splitter used to cut the uploaded document's text into overlapping
# chunks before they are embedded.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
)

# Questions that are asked against every uploaded document as soon as it has
# been indexed.
template_questions = [
    "1. What is this document about?",
    "2. What company is this document about?",
]
27
+
28
def _read_uploaded_text(file):
    """Return the text content of an uploaded file, or None if unsupported.

    Plain-text uploads are read as UTF-8. PDF uploads are read page by page
    with PyPDF2; pages for which no text can be extracted contribute an empty
    string.
    """
    if file.type == "text/plain":
        with open(file.path, "r", encoding="utf-8") as f:
            return f.read()
    if file.type == "application/pdf":
        # Read PDF from a temporary path given by file.path
        with open(file.path, "rb") as f:
            pdf_reader = PyPDF2.PdfReader(f)
            return "".join(page.extract_text() or "" for page in pdf_reader.pages)
    return None


@cl.on_chat_start
async def on_chat_start():
    """Ask for a document, index it, and answer the template questions.

    Steps:
      1. Prompt until the user uploads a text or PDF file.
      2. Extract the text, split it into chunks and embed them into Chroma.
      3. Build a ConversationalRetrievalChain with buffer memory and store it
         in the user session under the key "chain".
      4. Ask every entry of `template_questions` and send the answers,
         together with the source chunks they were retrieved from.

    If no text can be extracted from the upload, the status message is
    updated to say so and no chain is stored.
    """
    files = None

    # Keep prompting until an upload is actually received.
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a text or PDF file to begin!",
            accept=["text/plain", "application/pdf"],
            max_size_mb=20,
            timeout=180,
        ).send()

    file = files[0]

    msg = cl.Message(content=f"Processing `{file.name}`...")
    await msg.send()

    text = _read_uploaded_text(file)
    texts = text_splitter.split_text(text) if text else []
    if not texts:
        # Unsupported type or nothing extractable (e.g. an image-only PDF):
        # there is nothing to embed, so stop before building the index.
        msg.content = (
            f"Could not extract any text from `{file.name}`. "
            "Please upload a different text or PDF file."
        )
        await msg.update()
        return

    metadatas = [{"source": f"{i}-pl"} for i in range(len(texts))]
    embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
    # Chroma.from_texts is synchronous; make_async keeps it off the event loop.
    docsearch = await cl.make_async(Chroma.from_texts)(
        texts, embeddings, metadatas=metadatas
    )

    message_history = ChatMessageHistory()
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        chat_memory=message_history,
        return_messages=True,
    )

    chain = ConversationalRetrievalChain.from_llm(
        ChatOpenAI(
            model_name="gpt-4o-mini",
            temperature=0,
            streaming=True,
            openai_api_key=OPENAI_API_KEY,
        ),
        chain_type="stuff",
        retriever=docsearch.as_retriever(),
        memory=memory,
        return_source_documents=True,
    )

    cl.user_session.set("chain", chain)

    # Answer the template questions
    replies = []
    for question in template_questions:
        res = await chain.acall(question)
        answer = res["answer"]
        source_documents = res["source_documents"]  # type: List[Document]

        # One text element per retrieved chunk, named so the answer can cite it.
        text_elements = [
            cl.Text(content=source_doc.page_content, name=f"source_{source_idx}")
            for source_idx, source_doc in enumerate(source_documents or [])
        ]  # type: List[cl.Text]

        if text_elements:
            source_names = ", ".join(text_el.name for text_el in text_elements)
            answer += f"\nSources: {source_names}"
        else:
            answer += "\nNo sources found"

        replies.append((question, answer, text_elements))

    for question, answer, text_elements in replies:
        # Attach the elements so the cited source names resolve in the UI.
        await cl.Message(
            content=f"**{question}**\n{answer}", elements=text_elements
        ).send()

    msg.content = f"Processing `{file.name}` done. You can now ask more questions!"
    await msg.update()
113
+
114
@cl.on_message
async def main(message: cl.Message):
    """Answer a user message with the retrieval chain stored in the session.

    The chain is read from the user session under the key "chain". Its answer
    is sent back with one text element per retrieved source document; the
    element names are listed after the answer, or "No sources found" is
    appended when nothing was retrieved.

    If the session holds no chain yet, the user is asked to upload a
    document instead.
    """
    chain = cl.user_session.get("chain")  # type: ConversationalRetrievalChain
    if chain is None:
        # No indexed document in this session, so there is nothing to query.
        await cl.Message(
            content="Please upload a text or PDF file before asking questions."
        ).send()
        return

    cb = cl.AsyncLangchainCallbackHandler()

    res = await chain.acall(message.content, callbacks=[cb])
    answer = res["answer"]
    source_documents = res["source_documents"]  # type: List[Document]

    # One text element per retrieved chunk, named so the answer can cite it.
    text_elements = [
        cl.Text(content=source_doc.page_content, name=f"source_{source_idx}")
        for source_idx, source_doc in enumerate(source_documents or [])
    ]  # type: List[cl.Text]

    if text_elements:
        source_names = ", ".join(text_el.name for text_el in text_elements)
        answer += f"\nSources: {source_names}"
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer, elements=text_elements).send()