Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import os | |
| import tempfile | |
| import pickle | |
| import faiss | |
| import numpy as np | |
| from helper import extract_text_from_pdf, chunk_text, embedding_function, embedding_model, query_llm_with_context | |
| import logging | |
# Logging: configure the root logger at INFO and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Page identity, shared by the browser-tab configuration and the visible title.
_PAGE_ICON = "π"
_PAGE_NAME = "PDF RAG System"
_PAGE_INTRO = """
This application allows you to upload a PDF file, ask questions about its content, and get AI-generated answers based on the document.
"""

# Browser-tab title/icon and wide layout.
st.set_page_config(page_title=_PAGE_NAME, page_icon=_PAGE_ICON, layout="wide")

# Visible heading followed by the one-paragraph description of the app.
st.title(f"{_PAGE_ICON} {_PAGE_NAME}")
st.markdown(_PAGE_INTRO)
# Step 1: PDF upload widget (restricted to .pdf files).
st.header("1. Upload PDF")
uploaded_file = st.file_uploader("Choose a PDF file", type="pdf", key="pdf_uploader")

# Per-session values and their initial state. Each key is only seeded when it
# is missing, so values written on an earlier script run are preserved.
_SESSION_DEFAULTS = {
    'pdf_processed': False,  # True once a PDF has been chunked and indexed
    'index': None,           # FAISS index built from the chunk embeddings
    'chunks': None,          # text chunks, in the order they were indexed
    'pdf_path': None,        # path of the temporary copy of the upload
}
for state_key, initial_value in _SESSION_DEFAULTS.items():
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value
# Process the uploaded PDF: save it to disk, extract and chunk its text, embed
# the chunks and build a FAISS L2 index. Runs only while no PDF has been
# processed in this session; session state is updated only after every step
# succeeded, so a failure never leaves a half-initialised index behind.
if uploaded_file is not None and not st.session_state.pdf_processed:
    with st.spinner("Processing PDF..."):
        try:
            # The extraction helper takes a path, so the upload is copied to a
            # temporary file. delete=False keeps it on disk after the `with`
            # block closes it; the reset button removes it later.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
                tmp_file.write(uploaded_file.getvalue())
                st.session_state.pdf_path = tmp_file.name

            # Extract text from PDF and split it into overlapping chunks
            pdf_text = extract_text_from_pdf(st.session_state.pdf_path)
            chunks = chunk_text(pdf_text, chunk_size=1000, chunk_overlap=100)

            if not chunks:
                # Without chunks there are no embeddings, and reading the
                # embedding dimension below would fail with an IndexError.
                st.error(
                    "No text could be extracted from this PDF. "
                    "Please upload a PDF that contains selectable text."
                )
            else:
                # FAISS only accepts C-contiguous float32 matrices, so coerce
                # unconditionally (also when the helper already returned an
                # ndarray of another dtype).
                embeddings = np.ascontiguousarray(
                    embedding_function(chunks), dtype='float32'
                )
                if embeddings.ndim != 2 or embeddings.shape[0] != len(chunks):
                    raise ValueError(
                        f"Unexpected embedding shape {embeddings.shape} "
                        f"for {len(chunks)} chunks"
                    )

                # Build a flat (exact) L2 index sized to the embedding dimension
                dimension = embeddings.shape[1]
                index = faiss.IndexFlatL2(dimension)
                index.add(embeddings)

                # Persist the index and chunks in the working directory.
                # NOTE(review): these fixed paths are shared by every session
                # of the app — confirm whether per-session files are needed.
                faiss.write_index(index, "./faiss_index")
                with open("./document_chunks.pkl", 'wb') as f:
                    pickle.dump(chunks, f)

                # Publish the results to the session only now that all steps
                # completed.
                st.session_state.chunks = chunks
                st.session_state.index = index
                st.session_state.pdf_processed = True
                st.success(f"PDF processed successfully! {len(chunks)} chunks created.")
        except Exception as e:
            # Top-level UI boundary: report the failure in the page (same
            # style as the query section) and keep the traceback in the log.
            st.error(f"An error occurred while processing the PDF: {str(e)}")
            logger.exception("Error during PDF processing")
# Query section: embed the question, retrieve the nearest chunks from the
# FAISS index, ask the LLM with those chunks as context and show the result.
st.header("2. Ask a Question")
query = st.text_input("Enter your question about the PDF content:", key="query_input")

# Add a button to submit the query
if st.button("Get Answer", key="get_answer_button"):
    if not st.session_state.pdf_processed:
        # Tell the user why nothing happens instead of silently ignoring the click.
        st.warning("Please upload a PDF before asking a question.")
    elif not query or not query.strip():
        st.warning("Please enter a question.")
    else:
        with st.spinner("Retrieving relevant information and generating answer..."):
            try:
                # Generate embedding for the query (float32, as FAISS expects)
                query_embedding = embedding_model.encode(
                    [query], convert_to_numpy=True
                ).astype('float32')

                # Never ask for more neighbours than the index holds: FAISS
                # pads missing results with index -1, and chunks[-1] would
                # silently return the last chunk instead of failing.
                faiss_index = st.session_state.index
                n_results = min(5, faiss_index.ntotal)
                distances, indices = faiss_index.search(query_embedding, n_results)

                # Keep only real hits (defensive filter against -1 padding).
                hits = [
                    (int(idx), float(dist))
                    for idx, dist in zip(indices[0], distances[0])
                    if idx >= 0
                ]

                if not hits:
                    st.warning("No relevant content was found in the document.")
                else:
                    documents = [st.session_state.chunks[idx] for idx, _ in hits]

                    # Convert L2 distances (lower is better) to scores in
                    # [0, 1], where 1 is most similar and the farthest hit
                    # gets 0. If every distance is 0, all hits are exact
                    # matches; score them 1.0 rather than dividing by zero.
                    max_distance = max(dist for _, dist in hits)
                    if max_distance > 0:
                        similarity_scores = [1 - (dist / max_distance) for _, dist in hits]
                    else:
                        similarity_scores = [1.0] * len(hits)

                    # Create context from retrieved documents
                    context = (documents, similarity_scores)

                    # Query the LLM with context; cap top_n so it never
                    # exceeds the number of documents actually retrieved.
                    answer = query_llm_with_context(
                        query, context, top_n=min(3, len(documents))
                    )

                    # Display the answer
                    st.header("3. Answer")
                    st.write(answer)

                    # Display the retrieved documents (first 500 characters each)
                    with st.expander("View Retrieved Documents", expanded=False):
                        for i, (doc, score) in enumerate(zip(documents, similarity_scores)):
                            st.markdown(f"**Document {i+1}** (Relevance: {score:.4f})")
                            st.text(doc[:500] + "..." if len(doc) > 500 else doc)
                            st.markdown("---")
            except Exception as e:
                # Top-level UI boundary: show the error and log the traceback.
                st.error(f"An error occurred: {str(e)}")
                logger.exception("Error during query processing")
# Reset button: remove the temporary PDF copy, clear the per-session state and
# rerun the script so the page starts over.
# NOTE(review): the file uploader widget keeps its current file across the
# rerun, so the same PDF appears to be processed again right away — confirm
# whether the uploader should be cleared as well.
if st.button("Reset and Upload New PDF", key="reset_button"):
    # Clean up the temporary file; a failed delete must not block the reset.
    pdf_path = st.session_state.pdf_path
    if pdf_path and os.path.exists(pdf_path):
        try:
            os.unlink(pdf_path)
        except OSError:
            logger.warning(
                "Could not remove temporary file %s", pdf_path, exc_info=True
            )

    # Reset session state
    st.session_state.pdf_processed = False
    st.session_state.index = None
    st.session_state.chunks = None
    st.session_state.pdf_path = None

    # Reload the page. st.experimental_rerun is deprecated in favour of
    # st.rerun, so prefer st.rerun and fall back only when it is missing.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()
# Footer: a horizontal rule followed by the attribution line.
for footer_markdown in ("---", "Built with Streamlit, FAISS, and Hugging Face API"):
    st.markdown(footer_markdown)