# Scraped from a Hugging Face Spaces page (Space status at capture time: "Sleeping").
| """ | |
The code in this script is subject to the licence of 96harsh52/LLaMa_2_chatbot (https://github.com/96harsh52/LLaMa_2_chatbot).
YouTube walkthrough: https://www.youtube.com/watch?v=kXuHxI5ZcG0&list=PLrLEqwuz-mRIdQrfeCjeCyFZ-Pl6ffPIN&index=18
Llama 2 model (GGML quantization by TheBloke): https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/blob/main/llama-2-7b-chat.ggmlv3.q8_0.bin
    Note: load_llm() below actually loads the GGUF build instead: file llama-2-7b-chat.Q8_0.gguf from the repo TheBloke/Llama-2-7b-Chat-GGUF.
Llama 2 HF model (original): https://huggingface.co/meta-llama
Chainlit repository: https://github.com/Chainlit/chainlit
| """ | |
| from langchain import PromptTemplate | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
| from langchain.chains import RetrievalQA | |
| from langchain_community.llms import CTransformers | |
| import chainlit as cl | |
# Directory of the FAISS index that qa_bot() loads.
DB_FAISS_PATH = "vectorstore/db_faiss"

# Prompt for the "stuff" QA chain. {context} and {question} are the two input
# variables declared in set_custom_prompt().
custom_prompt_template = (
    "Use the following pieces of information to answer the user's question.\n"
    "If you don't know the answer, just say that you don't know, "
    "don't try to make up an answer.\n"
    "Context: {context}\n"
    "Question: {question}\n"
    "Only return the helpful answer below and nothing else.\n"
    "Helpful answer:\n"
)
def set_custom_prompt():
    """Build the PromptTemplate used by the retrieval QA chain.

    The template text is the module-level ``custom_prompt_template``; it
    declares two input variables, ``context`` and ``question``.

    Returns:
        PromptTemplate: the prompt handed to the chain via
        ``chain_type_kwargs={'prompt': ...}``.
    """
    variables = ["context", "question"]
    return PromptTemplate(
        template=custom_prompt_template,
        input_variables=variables,
    )
def load_llm(
    *,
    model="TheBloke/Llama-2-7b-Chat-GGUF",
    model_file="llama-2-7b-chat.Q8_0.gguf",
    max_new_tokens=512,
    temperature=0.5,
):
    """Load the quantized Llama 2 chat model through ctransformers.

    All arguments are keyword-only and default to the values this script has
    always used, so ``load_llm()`` keeps working unchanged.

    Keyword Args:
        model: Hugging Face Hub repo id (or local path) of the model.
        model_file: File inside ``model`` to load.
        max_new_tokens: Maximum number of tokens generated per answer.
        temperature: Sampling temperature.

    Returns:
        CTransformers: LangChain LLM wrapper around the loaded model.
    """
    # Generation settings must be passed via ``config``. CTransformers forwards
    # only ``config`` to ctransformers' model loader. Top-level
    # ``max_new_tokens``/``temperature`` kwargs are not CTransformers fields,
    # so they were never applied and the library defaults were used instead.
    llm = CTransformers(
        model=model,
        model_file=model_file,
        model_type="llama",
        config={"max_new_tokens": max_new_tokens, "temperature": temperature},
    )
    return llm
def retrieval_qa_chain(llm, prompt, db):
    """Assemble a RetrievalQA chain over the given vector store.

    Args:
        llm: Language model that writes the answer.
        prompt: Prompt for the "stuff" document-combining step.
        db: Vector store; its retriever returns the top 2 matches per query.

    Returns:
        RetrievalQA: chain whose output also carries the source documents.
    """
    retriever = db.as_retriever(search_kwargs={'k': 2})
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
def qa_bot():
    """Wire embeddings, the FAISS index, the LLM and the prompt into a QA chain.

    Returns:
        The RetrievalQA chain produced by ``retrieval_qa_chain``.
    """
    # NOTE(review): each call reloads the embedding model, the FAISS index and
    # the LLM. In this file it is called once per chat session and once per
    # final_result() query.
    embedder = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        model_kwargs={'device': 'cpu'},
    )
    # allow_dangerous_deserialization=True lets FAISS.load_local unpickle the
    # stored index data. This is only acceptable if DB_FAISS_PATH was written
    # by a trusted local ingest step (assumed here); never point it at
    # untrusted files.
    vector_store = FAISS.load_local(
        DB_FAISS_PATH, embedder, allow_dangerous_deserialization=True
    )
    return retrieval_qa_chain(load_llm(), set_custom_prompt(), vector_store)
def final_result(query):
    """Answer a single query with a freshly built QA chain.

    Args:
        query: The user's question.

    Returns:
        dict: the chain output. The answer is under ``'result'`` and the
        retrieved documents under ``'source_documents'``; the input
        ``'query'`` is echoed back as well.
    """
    # NOTE(review): qa_bot() reloads every model on each call, so this helper
    # is only suitable for one-off scripted queries.
    qa_result = qa_bot()
    # ``invoke`` replaces the deprecated ``Chain.__call__`` and returns the
    # same output dict.
    response = qa_result.invoke({'query': query})
    return response
@cl.on_chat_start
async def start():
    """Chainlit chat-start hook: build the QA chain and greet the user.

    The chain is stored in the user session under ``"chain"``, where the
    message handler reads it for every question in this session.
    """
    # Send the status message *before* the slow loading step, so the user sees
    # it while waiting. Previously it was sent only after loading had finished.
    msg = cl.Message(content="Starting the bot...")
    await msg.send()
    # qa_bot() is synchronous and heavy (embeddings, FAISS index, LLM). Run it
    # in a worker thread so the event loop keeps serving other sessions.
    chain = await cl.make_async(qa_bot)()
    msg.content = "Hi, Welcome to Medical Chatbot. What is your query?"
    await msg.update()
    cl.user_session.set("chain", chain)
@cl.on_message
async def main(message: cl.Message):
    """Chainlit message hook: answer ``message`` with the session's QA chain.

    The chain stored under ``"chain"`` in the user session is run with a
    Chainlit LangChain callback handler. The reply is the answer followed by
    the retrieved source documents, or "No sources found".
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    # The prompt asks for a "Helpful answer:", not "FINAL ANSWER", so the
    # prefix presumably never appears. Marking it as already reached makes the
    # handler stream the generated tokens straight away.
    cb.answer_reached = True
    # ``ainvoke`` replaces the deprecated ``acall``; callbacks go in ``config``.
    res = await chain.ainvoke(
        {"query": message.content}, config={"callbacks": [cb]}
    )
    answer = res["result"]
    sources = res["source_documents"]
    if sources:
        answer += "\nSources:" + str(sources)
    else:
        answer += "\nNo sources found"
    # If the handler already streamed the answer into a message, overwrite that
    # message instead of posting a duplicate. getattr keeps this working on
    # Chainlit versions whose handler lacks these attributes (the original
    # behaviour of sending a new message is used there).
    streamed = getattr(cb, "final_stream", None)
    if getattr(cb, "has_streamed_final_answer", False) and streamed is not None:
        streamed.content = answer
        await streamed.update()
    else:
        await cl.Message(content=answer).send()