# Spaces: Sleeping
# (Page-status text captured together with the source; it is not part of the program.)
# Install necessary packages
#!pip install streamlit
#!pip install wikipedia
#!pip install langchain  (imported below for the text splitters)
#!pip install langchain_community
#!pip install pypdf  (required by PyPDFLoader)
#!pip install sentence-transformers
#!pip install chromadb
#!pip install huggingface_hub
#!pip install transformers
import os
import random
import string
import tempfile

import chromadb
import numpy as np
import streamlit as st
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from huggingface_hub import HfApi, InferenceClient, login
from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from sentence_transformers import CrossEncoder
# Sidebar inputs and shared settings
# The two sidebar widgets are created in the same order as before: uploader first, token second.
uploaded_file = st.sidebar.file_uploader(label="Upload your PDF", type="pdf")
HF_TOKEN = st.sidebar.text_input(label="Enter your Hugging Face token:", value="", type="password")
# Chat model used by both the query-expansion and the answer-generation calls.
model_name = 'mistralai/Mistral-7B-Instruct-v0.3'
# Make sure the error-message slot exists before the token check reads it.
if "error_message" not in st.session_state:
    st.session_state["error_message"] = ""
# Function to validate token
def validate_token(token):
    """Return True when the Hugging Face Hub accepts *token*, False otherwise.

    The token is first registered with ``login`` (as the original code did)
    and then confirmed with an explicit ``whoami`` request made with that
    same token.

    Args:
        token: Access token string typed by the user.

    Returns:
        bool: True if both calls succeed; False for an empty token or when
        either call raises.
    """
    # An empty value can never be valid; skip the network round-trip.
    if not token:
        return False
    try:
        # Attempt to log in with the provided token
        login(token=token)
        # Check if the token is valid by trying to access some data
        HfApi().whoami(token=token)
    except Exception:
        # Deliberately broad: any failure here (rejected token, HTTP or
        # network error) is reported to the caller as "not valid".
        return False
    return True
# Show the outcome of the token check in the sidebar
if not HF_TOKEN:
    # Nothing entered on this run: re-show a stored error message, if there is one.
    if st.session_state.error_message:
        st.sidebar.error(st.session_state.error_message)
elif validate_token(HF_TOKEN):
    # A valid token clears any previously stored error.
    st.session_state.error_message = ""
    st.sidebar.success("Token is valid!")
else:
    # Store the message in session state before displaying it.
    st.session_state.error_message = "Invalid token. Please try again."
    st.sidebar.error(st.session_state.error_message)
if uploaded_file:
    # PyPDFLoader takes a file path, so the uploaded bytes are first written
    # to a temporary file. delete=False keeps the file on disk after the
    # handle is closed so it can be reopened by path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
        temp_file.write(uploaded_file.getbuffer())
        temp_file_path = temp_file.name
    try:
        # Load only after the handle is closed, so every byte has been flushed.
        docs = PyPDFLoader(temp_file_path).load()
    finally:
        # delete=False means nothing else removes this file; clean it up
        # explicitly so repeated runs do not pile up temporary files.
        os.remove(temp_file_path)
else:
    # Keep `docs` bound even when no file has been uploaded.
    docs = []
    st.warning("Please upload a PDF file.")
# Chat history is kept in session state; create the empty list only if it is not there yet.
st.session_state.setdefault("history", [])
# Function to generate a random string for collection name
def generate_random_string(max_length=60):
    """Return a random ASCII alphanumeric string usable as a collection name.

    The length is drawn uniformly between a lower bound and *max_length*.
    The lower bound is 3 whenever *max_length* allows it, because Chroma
    rejects collection names shorter than three characters; for smaller
    *max_length* values the lower bound is *max_length* itself.

    Args:
        max_length: Largest allowed length, between 1 and 60 inclusive.

    Returns:
        str: Letters and digits only.

    Raises:
        ValueError: If *max_length* is greater than 60 or smaller than 1.
    """
    if max_length > 60:
        raise ValueError("The maximum length cannot exceed 60 characters.")
    if max_length < 1:
        raise ValueError("The maximum length must be at least 1.")
    min_length = min(3, max_length)
    length = random.randint(min_length, max_length)
    characters = string.ascii_letters + string.digits
    return ''.join(random.choices(characters, k=length))
# Name of the Chroma collection that rag_advanced() creates. It is drawn at
# random each time this module-level statement executes.
# NOTE(review): presumably randomised so a new run does not collide with a
# collection created by an earlier run — confirm.
collection_name = generate_random_string()
# Function for query expansion
def augment_multiple_query(query):
    """Ask the chat model for related questions to widen retrieval.

    Sends *query* together with a fixed system prompt to ``model_name``
    through ``InferenceClient`` (authenticated with the module-level
    ``HF_TOKEN``) and splits the reply into lines.

    Args:
        query: The user's original question.

    Returns:
        list[str]: One entry per non-blank line of the model's reply, with
        surrounding whitespace removed. Empty when the reply has no text.
    """
    system_prompt = (
        "You are a helpful assistant.\n"
        "Suggest up to five additional related questions to help them find the information they need for the provided question.\n"
        "Suggest only short questions without compound sentences. Suggest a variety of questions that cover different aspects of the topic.\n"
        "Make sure they are complete questions, and that they are related to the original question."
    )
    client = InferenceClient(model_name, token=HF_TOKEN)
    completion = client.chat_completion(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query},
        ],
        max_tokens=500,
    )
    # Guard against a reply without text so the split below cannot fail.
    reply = completion.choices[0].message.content or ""
    # Blank lines would otherwise be passed on as empty retrieval queries.
    return [line.strip() for line in reply.splitlines() if line.strip()]
# Function to handle RAG-based question answering
def rag_advanced(user_query):
    """Answer *user_query* from the loaded PDF using retrieval and re-ranking.

    Steps:
      1. Join the page texts of the module-level ``docs`` and split them,
         first by characters and then by tokens.
      2. Index the chunks in a Chroma collection named ``collection_name``.
      3. Query the collection with the user question plus the questions
         returned by ``augment_multiple_query``.
      4. Re-rank the de-duplicated hits with a cross-encoder and keep the
         five best.
      5. Stream a chat completion from ``model_name`` that answers from
         those passages.

    Args:
        user_query: The question typed by the user.

    Returns:
        str: The generated answer, or a short notice when the PDF yielded
        no text to index.
    """
    # Text Splitting
    character_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", ". ", " ", ""], chunk_size=1000, chunk_overlap=0
    )
    full_text = "".join(doc.page_content for doc in docs)
    character_chunks = character_splitter.split_text(full_text)
    token_splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0, tokens_per_chunk=256)
    token_chunks = [
        piece
        for chunk in character_chunks
        for piece in token_splitter.split_text(chunk)
    ]
    if not token_chunks:
        # Nothing to index (for example a PDF without extractable text);
        # adding an empty batch to the collection would raise.
        return "No text could be extracted from the uploaded PDF."

    # Embedding and Document Storage
    embedding_function = SentenceTransformerEmbeddingFunction()
    chroma_client = chromadb.Client()
    chroma_collection = chroma_client.create_collection(collection_name, embedding_function=embedding_function)
    try:
        chunk_ids = [str(i) for i in range(len(token_chunks))]
        chroma_collection.add(ids=chunk_ids, documents=token_chunks)

        # Document Retrieval
        # Blank expansion lines are dropped so they are not embedded as queries.
        expanded_queries = [q.strip() for q in augment_multiple_query(user_query) if q.strip()]
        joint_query = [user_query, *expanded_queries]
        # Only the document texts are used below, so embeddings are not requested.
        results = chroma_collection.query(query_texts=joint_query, n_results=5, include=['documents'])
    finally:
        # The collection is needed for this question only; delete it so
        # collections do not accumulate in the client.
        chroma_client.delete_collection(collection_name)

    # dict.fromkeys removes duplicates while keeping first-seen order.
    unique_documents = list(
        dict.fromkeys(hit for hits in results['documents'] for hit in hits)
    )

    # Re-Ranking
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    pairs = [[user_query, document] for document in unique_documents]
    scores = cross_encoder.predict(pairs)
    top_indices = np.argsort(scores)[::-1][:5]  # highest score first
    top_documents = [unique_documents[idx] for idx in top_indices]

    # LLM Reference
    system_prompt = (
        "You are a helpful assitant.\n"
        "You will be shown the user's questions, and the relevant information from the related documents.\n"
        "Answer the user's question using only this information."
    )
    client = InferenceClient(model_name, token=HF_TOKEN)
    stream = client.chat_completion(
        messages=[
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": f"Questions: {user_query}. \n Information: {top_documents}"
            },
        ],
        max_tokens=500,
        stream=True,
    )
    # A streamed chunk may carry no choices or no text; treat both as empty
    # instead of failing on str + None.
    return "".join(
        chunk.choices[0].delta.content or ""
        for chunk in stream
        if chunk.choices
    )
# Streamlit UI
# Page header: title, a one-line instruction and a link to the Hugging Face token page.
st.title("PDF RAG Chatbot")
st.markdown("Upload your PDF and enter your 🤗 token!")
st.link_button("Get Token Here", "https://huggingface.co/settings/tokens")
# Input box for the user to type their message
# The question box is only created once a PDF has been uploaded.
if uploaded_file:
    user_input = st.text_input("You: ", "", placeholder="Type your question here...")
    if user_input:
        # Answer the question and record the exchange in the session history.
        response = rag_advanced(user_input)
        st.session_state.history.append({"user": user_input, "bot": response})
    # Display the conversation history
    # NOTE(review): the source this was recovered from had lost its
    # indentation; this loop is placed inside `if uploaded_file:` — confirm
    # it was not meant to run at top level.
    for chat in st.session_state.history:
        st.write(f"You: {chat['user']}")
        st.write(f"Bot: {chat['bot']}")
# Footer: a divider line followed by a static description of the app.
st.markdown("-----------------")
st.markdown("What is this app?")
st.markdown("""This is a simple RAG application using PDF import.
The model for chat is Mistral-7B-Instruct-v0.3.
Main libraries: Langchain (text splitting), Chromadb (vector store)
This RAG uses query expansion and re-ranking to improve the quality.
Feel free to check the files or DM me for any questions. Thank you.""")