# NOTE(review): the original file began with Hugging Face Spaces page chrome
# ("Spaces: Sleeping") captured during copy/paste — not part of the program.
import os

import streamlit as st
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts.chat import MessagesPlaceholder, ChatPromptTemplate

# Google API key: read from the TOKEN env var (presumably a Spaces secret) and
# expose it under the name the langchain-google-genai client expects.
# NOTE(review): if TOKEN is unset, os.getenv returns None and this assignment
# raises TypeError at startup — confirm the secret is configured.
os.environ["GOOGLE_API_KEY"] = os.getenv("TOKEN")

# Initialize the Gemini embeddings model and chat model.
gemini_embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
model = ChatGoogleGenerativeAI(model="gemini-1.0-pro", convert_system_message_to_human=True)

# Load the PDF (must ship alongside this script) and split it into
# overlapping chunks for embedding.
pdf_loader = PyMuPDFLoader(file_path="./Debyez detialed proposal 2.pdf")
doc = pdf_loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
splits = text_splitter.split_documents(doc)

# In-memory Chroma vector store (persist_directory=None → nothing written to
# disk) and the retriever the RAG chain will query.
vectorstore = Chroma.from_documents(documents=splits, embedding=gemini_embeddings, persist_directory=None)
retriever = vectorstore.as_retriever()
# NOTE(review): a duplicate `from langchain.prompts.chat import ...` line stood
# here; it repeated the import at the top of the file verbatim and was removed.

# Answer-generation prompt. The system message declares `{context}`, the slot
# the stuff-documents chain fills with the retrieved passages.
system_prompt = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer the question. "
    "If you don't know the answer, say that you don't know. "
    "Use three sentences maximum and keep the answer concise."
    "\n\n"
    "{context}"
)
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
# 'context' names the prompt variable the retrieved documents are stuffed into.
question_answer_chain = create_stuff_documents_chain(model, qa_prompt, document_variable_name="context")
# Question-contextualization prompt: rewrites a follow-up question into a
# standalone one so the retriever does not need the chat history.
retriever_prompt = (
    "Given a chat history and the latest user question which might reference context in the chat history,"
    "formulate a standalone question which can be understood without the chat history."
)
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", retriever_prompt),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
    ]
)
history_aware_retriever = create_history_aware_retriever(model, retriever, contextualize_q_prompt)

# NOTE(review): the original re-defined qa_prompt and question_answer_chain
# here, identical to the earlier definitions except for dropping
# document_variable_name (whose default is "context" anyway); the redundant
# shadowing definitions were removed.
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
# Per-session chat histories, keyed by session id. Module-level so histories
# persist across Streamlit reruns within the same server process.
store = {}


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the chat history for *session_id*, creating it on first use."""
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]


# Wrap the RAG chain so every invocation reads the session's prior messages
# into "chat_history" and appends the new turn ("input" → "answer") afterwards.
conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)
# ---------------------------------------------------------------------------
# Streamlit UI
# ---------------------------------------------------------------------------
st.title("Conversational RAG Assistant")
st.write("Ask questions based on the context in the uploaded PDF document.")

# Fixed session id: every visitor of this process shares one chain history.
session_id = "user_session"

# UI transcript (separate from the chain-side history in `store`).
if "history" not in st.session_state:
    st.session_state["history"] = []

user_input = st.text_input("Enter your question:", "")

if st.button("Submit"):
    if user_input:
        # Record the question, run the RAG chain, record the answer.
        st.session_state["history"].append({"role": "user", "content": user_input})
        result = conversational_rag_chain.invoke(
            {"input": user_input},
            config={"configurable": {"session_id": session_id}},
        )["answer"]
        st.session_state["history"].append({"role": "assistant", "content": result})
        # NOTE(review): the original assigned user_input = "" here intending to
        # clear the input box, but rebinding the local has no effect on the
        # widget (its value lives in Streamlit's widget state). The dead
        # assignment was removed; to actually clear the box, use st.form with
        # clear_on_submit=True or a widget key plus an on_click callback.

# Render the conversation transcript.
for message in st.session_state["history"]:
    if message["role"] == "user":
        st.markdown(f"**You:** {message['content']}")
    else:
        st.markdown(f"**Assistant:** {message['content']}")