Spaces:
Runtime error
Runtime error
| import asyncio | |
| import json | |
| from websockets.server import serve | |
| from langchain.vectorstores import Chroma | |
| from langchain_huggingface.embeddings import HuggingFaceEmbeddings | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_huggingface.llms import HuggingFaceEndpoint | |
| from langchain.document_loaders import TextLoader | |
| from langchain.document_loaders import DirectoryLoader | |
| from langchain import hub | |
| from langchain_core.runnables import RunnablePassthrough | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain.chains import create_history_aware_retriever | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain.chains import create_retrieval_chain | |
| from langchain.chains.combine_documents import create_stuff_documents_chain | |
| from langchain_core.runnables.history import RunnableWithMessageHistory | |
| from langchain_core.chat_history import BaseChatMessageHistory | |
| from langchain_community.chat_message_histories import ChatMessageHistory | |
# --- Knowledge base indexing -------------------------------------------------
# Load every file matching ./*.txt inside ./database, using TextLoader for each.
loader = DirectoryLoader(
    './database',
    glob="./*.txt",
    loader_cls=TextLoader,
)
documents = loader.load()

# Split the loaded documents with chunk_size=1000 and chunk_overlap=200.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200,
)
texts = text_splitter.split_documents(documents)

# Embed the chunks with the default HuggingFaceEmbeddings model and store them
# in a Chroma collection backed by the 'db' directory.
persist_directory = 'db'
embedding = HuggingFaceEmbeddings()
vectordb = Chroma.from_documents(
    documents=texts,
    embedding=embedding,
    persist_directory=persist_directory,
)
vectordb.persist()

# Drop the handle used for building and reopen the store from the persist
# directory with the same embedding function.
vectordb = None
vectordb = Chroma(
    persist_directory=persist_directory,
    embedding_function=embedding,
)
def format_docs(docs):
    """Join the ``page_content`` of every document in ``docs`` with blank lines.

    Args:
        docs: Iterable of objects exposing a ``page_content`` string attribute.

    Returns:
        A single string with the contents separated by ``"\\n\\n"``; an empty
        string when ``docs`` is empty.
    """
    contents = [doc.page_content for doc in docs]
    return "\n\n".join(contents)
# Retriever over the reopened Chroma store, using as_retriever() defaults.
retriever = vectordb.as_retriever()

# Prompt template pulled from the LangChain hub.
prompt = hub.pull("rlm/rag-prompt")

# Text-generation endpoint for the Mixtral instruct model.
llm = HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1")

# Single-turn pipeline: the input is passed through unchanged as "question",
# the retriever output is flattened by format_docs into "context", and the
# result flows through the prompt, the LLM and a string output parser.
# NOTE: the name rag_chain is bound again later in this file.
rag_chain = (
    {
        "context": retriever | format_docs,
        "question": RunnablePassthrough(),
    }
    | prompt
    | llm
    | StrOutputParser()
)
# System instruction for the question-rewriting step. It is assembled from
# adjacent literals; the resulting text is one single line.
contextualize_q_system_prompt = (
    "Given a chat history and the latest user question "
    "which might reference context in the chat history, formulate a standalone question "
    "which can be understood without the chat history. Do NOT answer the question, "
    "just reformulate it if needed and otherwise return it as is."
)

# Chat prompt layout: system instruction, then the prior messages supplied
# under "chat_history", then the new user input.
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)

# Retriever built from the LLM, the base retriever and the rewriting prompt.
history_aware_retriever = create_history_aware_retriever(
    llm,
    retriever,
    contextualize_q_prompt,
)
# System instruction for the answering step. The instruction sentences form one
# line; the retrieved context is substituted for {context}.
qa_system_prompt = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer the question. "
    "If you don't know the answer, just say that you don't know. "
    "Use three sentences maximum and keep the answer concise."
    # Fix: a blank line separates the instructions from the context. Previously
    # "{context}" was glued directly onto "concise." with no separator.
    "\n\n"
    "{context}"
)

# Chat prompt layout: system instruction (with context), then the prior
# messages supplied under "chat_history", then the new user input.
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
# In-process registry of chat histories, keyed by session id.
store = {}


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the chat history registered for ``session_id``.

    A new ``ChatMessageHistory`` is created and stored the first time a session
    id is seen; later calls with the same id return that same object.
    """
    try:
        return store[session_id]
    except KeyError:
        history = store[session_id] = ChatMessageHistory()
        return history
# Answering chain built from the LLM and qa_prompt.
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

# Retrieval chain combining the history-aware retriever with the answering
# chain. This rebinds the name rag_chain.
rag_chain = create_retrieval_chain(
    history_aware_retriever,
    question_answer_chain,
)

# Wrap the retrieval chain with per-session message history: user text is read
# from "input", prior messages are supplied as "chat_history", and the reply is
# taken from "answer". Histories come from get_session_history.
conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)

# Startup banner, printed as three lines.
print("-------", "started", "-------", sep="\n")
def _parse_request(raw):
    """Decode one websocket frame into a ``(message, token)`` pair.

    Returns ``None`` when the frame is not valid JSON, is not a JSON object, or
    lacks a ``"message"`` or ``"token"`` key.
    """
    try:
        data = json.loads(raw)
    except (TypeError, ValueError):
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        return None
    if not isinstance(data, dict):
        return None
    # Check the keys of the decoded object, not substrings of the raw frame.
    if "message" not in data or "token" not in data:
        return None
    return data["message"], data["token"]


def _record_retrieved_docs(token, docs, path="userData.json"):
    """Store ``str(docs)`` under ``token``'s ``"docs"`` entry in the JSON file.

    A missing, empty or malformed file is treated as an empty mapping.
    NOTE(review): the original opened this file with mode "w" and never wrote
    to it; persisting the update is the presumed intent -- confirm.
    """
    try:
        with open(path, encoding="utf-8") as fh:
            user_data = json.load(fh)
    except (FileNotFoundError, json.JSONDecodeError):
        user_data = {}
    if not isinstance(user_data, dict):
        user_data = {}

    entry = user_data.get(token)
    if not isinstance(entry, dict):
        # First time this token is seen (or its entry is unusable).
        entry = user_data[token] = {}
    entry["docs"] = str(docs)

    with open(path, "w", encoding="utf-8") as fh:
        json.dump(user_data, fh)


async def echo(websocket):
    """Answer chat requests arriving on one websocket connection.

    Each incoming frame is expected to be a JSON object with a ``"message"``
    (the user's question) and a ``"token"`` (used as the chat session id).
    For every valid frame the handler:

    1. fetches the documents the retriever returns for the message,
    2. records their string form under the token in ``userData.json``,
    3. invokes ``conversational_rag_chain`` with the token as ``session_id``,
    4. sends ``{"response": <answer>}`` back as JSON.

    A frame that cannot be parsed or lacks one of the two keys ends the
    handler, as the original did for frames missing those fields.

    Args:
        websocket: Connection object that yields incoming frames when iterated
            asynchronously and exposes an awaitable ``send``.
    """
    loop = asyncio.get_running_loop()
    async for message in websocket:
        request = _parse_request(message)
        if request is None:
            return
        m, token = request

        # The retriever and the chain are synchronous. Run them in the default
        # executor so the event loop stays free for other connections.
        docs = await loop.run_in_executor(
            None, lambda: retriever.get_relevant_documents(m)
        )
        _record_retrieved_docs(token, docs)

        result = await loop.run_in_executor(
            None,
            lambda: conversational_rag_chain.invoke(
                {"input": m},
                config={
                    "configurable": {"session_id": token}
                },
            ),
        )
        response = result["answer"]
        await websocket.send(json.dumps({"response": response}))
async def main():
    """Serve ``echo`` on 0.0.0.0:7860 and keep running indefinitely."""
    async with serve(echo, "0.0.0.0", 7860):
        # Nothing ever resolves this future, so awaiting it keeps the server
        # context open.
        never_done = asyncio.Future()
        await never_done


asyncio.run(main())