araeyn committed on
Commit
a23b8d7
·
verified ·
1 Parent(s): 2e7f253

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -110
app.py CHANGED
@@ -47,120 +47,117 @@ async def echo(websocket):
47
  async def main():
48
  async with serve(echo, "0.0.0.0", 7860):
49
  await asyncio.Future()
50
- def g():
51
- global retriever, conversational_rag_chain
52
- if not os.path.isdir('database'):
53
- os.system("unzip database.zip")
54
-
55
- loader = DirectoryLoader('./database', glob="./*.txt", loader_cls=TextLoader)
56
-
57
- documents = loader.load()
58
-
59
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
60
- splits = text_splitter.split_documents(documents)
61
-
62
- print()
63
- print("-------")
64
- print("TextSplitter, DirectoryLoader")
65
- print("-------")
66
-
67
- persist_directory = 'db'
68
-
69
- # embedding = HuggingFaceInferenceAPIEmbeddings(api_key=os.environ["HUGGINGFACE_API_KEY"], model=)
70
- model_name = "BAAI/bge-large-en"
71
- model_kwargs = {'device': 'cpu'}
72
- encode_kwargs = {'normalize_embeddings': True}
73
- embedding = HuggingFaceBgeEmbeddings(
74
- model_name=model_name,
75
- model_kwargs=model_kwargs,
76
- encode_kwargs=encode_kwargs,
77
- show_progress=True,
78
- )
79
-
80
- print()
81
- print("-------")
82
- print("Embeddings")
83
- print("-------")
84
-
85
- vectorstore = Chroma.from_documents(documents=splits, embedding=embedding)
86
-
87
- def format_docs(docs):
88
- return "\n\n".join(doc.page_content for doc in docs)
89
-
90
- retriever = vectorstore.as_retriever()
91
-
92
- prompt = hub.pull("rlm/rag-prompt")
93
- llm = HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1")
94
- rag_chain = (
95
- {"context": retriever | format_docs, "question": RunnablePassthrough()}
96
- | prompt
97
- | llm
98
- | StrOutputParser()
99
- )
100
-
101
- print()
102
- print("-------")
103
- print("Retriever, Prompt, LLM, Rag_Chain")
104
- print("-------")
105
-
106
- ### Contextualize question ###
107
- contextualize_q_system_prompt = """Given a chat history and the latest user question \
108
- which might reference context in the chat history, formulate a standalone question \
109
- which can be understood without the chat history. Do NOT answer the question, \
110
- just reformulate it if needed and otherwise return it as is."""
111
- contextualize_q_prompt = ChatPromptTemplate.from_messages(
112
- [
113
- ("system", contextualize_q_system_prompt),
114
- MessagesPlaceholder("chat_history"),
115
- ("human", "{input}"),
116
- ]
117
- )
118
- history_aware_retriever = create_history_aware_retriever(
119
- llm, retriever, contextualize_q_prompt
120
- )
121
-
122
-
123
- ### Answer question ###
124
- qa_system_prompt = """You are an assistant for question-answering tasks. \
125
- Use the following pieces of retrieved context to answer the question. \
126
- If you don't know the answer, just say that you don't know. \
127
- Use three sentences maximum and keep the answer concise.\
128
-
129
- {context}"""
130
- qa_prompt = ChatPromptTemplate.from_messages(
131
- [
132
- ("system", qa_system_prompt),
133
- MessagesPlaceholder("chat_history"),
134
- ("human", "{input}"),
135
- ]
136
- )
137
- question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
138
-
139
- rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
140
-
141
-
142
- ### Statefully manage chat history ###
143
- store = {}
144
-
145
-
146
- def get_session_history(session_id: str) -> BaseChatMessageHistory:
147
- if session_id not in store:
148
- store[session_id] = ChatMessageHistory()
149
- return store[session_id]
150
-
151
-
152
- conversational_rag_chain = RunnableWithMessageHistory(
153
- rag_chain,
154
- get_session_history,
155
- input_messages_key="input",
156
- history_messages_key="chat_history",
157
- output_messages_key="answer",
158
- )
159
 
160
  def f():
161
  asyncio.run(main())
162
  Process(target=f).start()
163
- Process(target=g).start()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  """
165
  websocket
166
  streamlit app ~> backend
 
47
  async def main():
48
  async with serve(echo, "0.0.0.0", 7860):
49
  await asyncio.Future()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  def f():
52
  asyncio.run(main())
53
  Process(target=f).start()
54
+ if not os.path.isdir('database'):
55
+ os.system("unzip database.zip")
56
+
57
+ loader = DirectoryLoader('./database', glob="./*.txt", loader_cls=TextLoader)
58
+
59
+ documents = loader.load()
60
+
61
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
62
+ splits = text_splitter.split_documents(documents)
63
+
64
+ print()
65
+ print("-------")
66
+ print("TextSplitter, DirectoryLoader")
67
+ print("-------")
68
+
69
+ persist_directory = 'db'
70
+
71
+ # embedding = HuggingFaceInferenceAPIEmbeddings(api_key=os.environ["HUGGINGFACE_API_KEY"], model=)
72
+ model_name = "BAAI/bge-large-en"
73
+ model_kwargs = {'device': 'cpu'}
74
+ encode_kwargs = {'normalize_embeddings': True}
75
+ embedding = HuggingFaceBgeEmbeddings(
76
+ model_name=model_name,
77
+ model_kwargs=model_kwargs,
78
+ encode_kwargs=encode_kwargs,
79
+ show_progress=True,
80
+ )
81
+
82
+ print()
83
+ print("-------")
84
+ print("Embeddings")
85
+ print("-------")
86
+
87
+ vectorstore = Chroma.from_documents(documents=splits, embedding=embedding)
88
+
89
+ def format_docs(docs):
90
+ return "\n\n".join(doc.page_content for doc in docs)
91
+
92
+ retriever = vectorstore.as_retriever()
93
+
94
+ prompt = hub.pull("rlm/rag-prompt")
95
+ llm = HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1")
96
+ rag_chain = (
97
+ {"context": retriever | format_docs, "question": RunnablePassthrough()}
98
+ | prompt
99
+ | llm
100
+ | StrOutputParser()
101
+ )
102
+
103
+ print()
104
+ print("-------")
105
+ print("Retriever, Prompt, LLM, Rag_Chain")
106
+ print("-------")
107
+
108
+ ### Contextualize question ###
109
+ contextualize_q_system_prompt = """Given a chat history and the latest user question \
110
+ which might reference context in the chat history, formulate a standalone question \
111
+ which can be understood without the chat history. Do NOT answer the question, \
112
+ just reformulate it if needed and otherwise return it as is."""
113
+ contextualize_q_prompt = ChatPromptTemplate.from_messages(
114
+ [
115
+ ("system", contextualize_q_system_prompt),
116
+ MessagesPlaceholder("chat_history"),
117
+ ("human", "{input}"),
118
+ ]
119
+ )
120
+ history_aware_retriever = create_history_aware_retriever(
121
+ llm, retriever, contextualize_q_prompt
122
+ )
123
+
124
+
125
+ ### Answer question ###
126
+ qa_system_prompt = """You are an assistant for question-answering tasks. \
127
+ Use the following pieces of retrieved context to answer the question. \
128
+ If you don't know the answer, just say that you don't know. \
129
+ Use three sentences maximum and keep the answer concise.\
130
+
131
+ {context}"""
132
+ qa_prompt = ChatPromptTemplate.from_messages(
133
+ [
134
+ ("system", qa_system_prompt),
135
+ MessagesPlaceholder("chat_history"),
136
+ ("human", "{input}"),
137
+ ]
138
+ )
139
+ question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
140
+
141
+ rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
142
+
143
+
144
+ ### Statefully manage chat history ###
145
+ store = {}
146
+
147
+
148
+ def get_session_history(session_id: str) -> BaseChatMessageHistory:
149
+ if session_id not in store:
150
+ store[session_id] = ChatMessageHistory()
151
+ return store[session_id]
152
+
153
+
154
+ conversational_rag_chain = RunnableWithMessageHistory(
155
+ rag_chain,
156
+ get_session_history,
157
+ input_messages_key="input",
158
+ history_messages_key="chat_history",
159
+ output_messages_key="answer",
160
+ )
161
  """
162
  websocket
163
  streamlit app ~> backend