Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader | |
| from langchain.prompts import PromptTemplate | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_community.llms import CTransformers | |
| from langchain.chains import RetrievalQA | |
| from huggingface_hub import snapshot_download | |
# Directory holding the persisted FAISS index (resolved relative to the
# current working directory).
DB_FAISS_PATH = './faiss_data'

# Prompt text for the LLM. The {context} and {question} placeholders are the
# two input variables that set_custom_prompt() declares for the QA chain.
custom_prompt_template = "\n".join([
    "Use the following pieces of information to answer the user's question. If you don't know the answer, just say that you don't know, don't try to make up an answer.",
    "Context: {context}",
    "Question: {question}",
    "Only return the helpful answer below and nothing else.",
    "Helpful answer: ",
])
def set_custom_prompt():
    """Wrap ``custom_prompt_template`` in a LangChain ``PromptTemplate``.

    Returns:
        A PromptTemplate whose input variables are ``context`` and
        ``question``, matching the placeholders in the template text.
    """
    variables = ['context', 'question']
    return PromptTemplate(input_variables=variables, template=custom_prompt_template)
# Function to create a retrieval-based QA chain
def retrieval_qa_chain(llm, prompt, db, k=2):
    """Assemble a RetrievalQA chain that answers questions from ``db``.

    Args:
        llm: LangChain LLM that generates the final answer.
        prompt: PromptTemplate with ``context`` and ``question`` variables.
        db: Vector store exposing ``as_retriever`` (a FAISS index in this app).
        k: Number of documents retrieved per query. Defaults to 2, the value
            previously hard-coded, so existing callers are unaffected.

    Returns:
        The RetrievalQA chain. Because ``return_source_documents=True``, its
        output includes the retrieved documents alongside the answer.
    """
    retriever = db.as_retriever(search_kwargs={'k': k})
    return RetrievalQA.from_chain_type(
        llm=llm,
        # "stuff" places all retrieved documents into a single prompt.
        chain_type='stuff',
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
# Load the LLM model (e.g., LLaMA model from a local path)
def load_llm(model_path="./llama-2-7b-chat.ggmlv3.q4_0.bin", max_new_tokens=1024, temperature=0.5):
    """Load a local GGML LLaMA model through LangChain's CTransformers wrapper.

    Args:
        model_path: Path to the GGML model file. Defaults to the quantized
            LLaMA-2-7B chat weights expected in the working directory.
        max_new_tokens: Maximum number of tokens generated per answer.
        temperature: Sampling temperature for generation.

    Returns:
        A ``CTransformers`` LLM instance.
    """
    # Generation settings must be passed through ``config``. LangChain's
    # CTransformers wrapper forwards only that dict to the ctransformers
    # backend, so top-level ``max_new_tokens=``/``temperature=`` keywords
    # never reach the model.
    # NOTE(review): ctransformers may use a short default context window for
    # GGML LLaMA models. If prompts get truncated, consider adding
    # 'context_length' to this config and confirm against the backend.
    return CTransformers(
        model=model_path,
        model_type="llama",
        config={'max_new_tokens': max_new_tokens, 'temperature': temperature},
    )
# Main chatbot logic
@st.cache_resource
def qa_bot():
    """Build the retrieval QA chain: embeddings + FAISS index + LLM + prompt.

    ``main`` calls this for every chat message. ``st.cache_resource`` makes
    the expensive setup run once per server process: loading the
    sentence-transformer, deserializing the FAISS index and loading the local
    LLaMA weights. Later calls return the same chain object.

    NOTE(review): cached resources are shared by all sessions. Confirm the
    CTransformers model tolerates concurrent calls if several users chat at
    once.

    Returns:
        The RetrievalQA chain produced by ``retrieval_qa_chain``.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'},
    )
    # FAISS.load_local unpickles the stored docstore, which LangChain guards
    # behind allow_dangerous_deserialization. This is only acceptable if the
    # index at DB_FAISS_PATH is one you built yourself (presumably by a
    # separate ingestion step; confirm). Never point it at untrusted files.
    db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
    llm = load_llm()
    qa_prompt = set_custom_prompt()
    return retrieval_qa_chain(llm, qa_prompt, db)
# Streamlit main app
def main():
    """Render the RAG chat page and answer the newest user question.

    Streamlit re-executes the script on every interaction. The conversation
    is therefore stored in ``st.session_state.messages`` as
    ``{"role": ..., "content": ...}`` dicts and replayed before any new input
    is handled.
    """
    st.title("RAG")

    # Create the chat history the first time this session runs.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation so far.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    prompt = st.chat_input("What is your medical query?")
    if not prompt:
        return

    # Show and record the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate and display the assistant's answer.
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            qa_chain = qa_bot()
            # Calling a Chain directly (chain({...})) is deprecated since
            # LangChain 0.1.0. invoke() is the supported entry point and
            # returns the same output dict.
            response = qa_chain.invoke({'query': prompt})
        answer = response["result"]
        st.markdown(answer)
    st.session_state.messages.append({"role": "assistant", "content": answer})


if __name__ == '__main__':
    main()