Spaces: Streamlit RAG chat app (Hugging Face Space page — build status: Build error)
| import os | |
| import streamlit as st | |
| from langchain.chains import RetrievalQA | |
| from langchain.embeddings import HuggingFaceEmbeddings | |
| from langchain.vectorstores import Chroma | |
| from langchain.llms import HuggingFaceHub | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from dotenv import load_dotenv | |
| from PyPDF2 import PdfReader | |
# Load environment variables (e.g. HUGGINGFACEHUB_API_TOKEN) from a local .env file.
load_dotenv()
# Mapping of model display names (shown in the sidebar dropdown) to their
# Hugging Face Hub repository ids, used later to instantiate HuggingFaceHub.
model_links = {
    "Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B-Instruct",
    "Mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
    "Gemma-7B": "google/gemma-1.1-7b-it",
    "Gemma-2B": "google/gemma-1.1-2b-it",
    "Zephyr-7B-β": "HuggingFaceH4/zephyr-7b-beta",
}
# Function to read PDF files and extract text along with their names
def read_pdf_files(directory):
    """Read every ``*.pdf`` file in *directory* and extract its text.

    Args:
        directory: Path of the folder to scan (non-recursive).

    Returns:
        A list of ``(filename, text)`` tuples, one per PDF found.

    Note:
        ``PdfReader.page.extract_text()`` can return ``None`` for pages with
        no extractable text (e.g. scanned/image-only pages). The original
        code concatenated that result directly, raising ``TypeError``; we
        coalesce ``None`` to ``""`` so such pages are skipped safely.
    """
    documents = []
    for filename in os.listdir(directory):
        if not filename.endswith(".pdf"):
            continue
        with open(os.path.join(directory, filename), "rb") as file:
            reader = PdfReader(file)
            # join at C speed instead of quadratic "text +=" in a loop;
            # "or ''" guards against extract_text() returning None.
            text = "".join(page.extract_text() or "" for page in reader.pages)
        documents.append((filename, text))
    return documents
| # Initialize ChromaDB with PDF data | |
def initialize_chromadb_from_pdfs(directory):
    """Build a Chroma vector store from every PDF under *directory*.

    Each document's text is split into overlapping chunks, embedded with a
    MiniLM sentence-transformer, and indexed with the source filename stored
    as chunk metadata.

    Returns:
        ``(chromadb, split_docs)`` — the Chroma store and a list of
        ``(source_filename, chunk_text)`` pairs for every indexed chunk.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    # Flatten (file, full_text) pairs into (file, chunk) pairs in one pass.
    split_docs = [
        (doc_name, chunk)
        for doc_name, doc_text in read_pdf_files(directory)
        for chunk in splitter.split_text(doc_text)
    ]
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    texts = [chunk for _, chunk in split_docs]
    metadatas = [{"source": name} for name, _ in split_docs]
    chromadb = Chroma.from_texts(texts, embeddings, metadatas=metadatas)
    return chromadb, split_docs
# --- Retriever setup --------------------------------------------------------
# Build the vector store once at startup; retrieve the 5 most similar chunks.
chromadb, split_docs = initialize_chromadb_from_pdfs("docs")
retriever = chromadb.as_retriever(search_type="similarity", search_kwargs={"k": 5})

# --- Sidebar controls -------------------------------------------------------
selected_model = st.sidebar.selectbox("Select Model", model_links.keys())
temp_values = st.sidebar.slider('Select a temperature value', 0.0, 1.0, 0.5)
st.sidebar.button('Reset Chat', on_click=lambda: st.session_state.clear())

# Resolve the Hugging Face repo id for the chosen model.
repo_id = model_links[selected_model]

# --- Chat history -----------------------------------------------------------
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay prior turns on every Streamlit rerun.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# --- Handle a new user prompt -----------------------------------------------
if prompt := st.chat_input(f"Hi I'm {selected_model}, ask me a question"):
    # Display user message in the chat container and persist it.
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Perform RAG
    with st.spinner('Processing query with RAG...'):
        # BUG FIX: the temperature slider value was read but never passed to
        # the model, so the sidebar control had no effect. Forward it via
        # model_kwargs so sampling temperature actually follows the slider.
        llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temp_values})
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
        )
        response = qa_chain({"query": prompt})
        helpful_answer = response['result']
        source_documents = response['source_documents']

        # Heuristic truncation recovery: if the answer ends with "...", pull
        # extra context from the source documents.
        if helpful_answer.endswith("..."):
            # PERF FIX: the original re-parsed the ENTIRE "docs" folder once
            # per retrieved document; read the corpus a single time instead.
            corpus = dict(read_pdf_files("docs"))
            for doc in source_documents:
                if doc is not None:
                    doc_text = corpus.get(doc.metadata["source"], "")
                    start_idx = doc_text.find(helpful_answer)
                    if start_idx != -1:
                        # Add some extra context beyond the truncated answer.
                        end_idx = start_idx + len(helpful_answer) + 100
                        helpful_answer += "\n\n" + doc_text[start_idx:end_idx]

    # --- Render assistant turn ----------------------------------------------
    with st.chat_message("assistant"):
        st.markdown(helpful_answer)
        # Display references in an expander
        if source_documents:
            with st.expander("References", expanded=False):
                for doc in source_documents:
                    st.markdown(f"- **{doc.metadata['source']}**")

    # Only add the helpful answer and references to the session state
    references = "\n".join([f"- **{doc.metadata['source']}**" for doc in source_documents if doc])
    st.session_state.messages.append({"role": "assistant", "content": f"{helpful_answer}\n\n**References:**\n{references}"})