""" RAG System Demo - Streamlit Application A self-contained RAG (Retrieval-Augmented Generation) demo using google/flan-t5-small for text generation and sentence-transformers for semantic search with ChromaDB as the vector store. """ import io import os import uuid import logging from typing import Optional import numpy as np import pandas as pd import streamlit as st logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Custom CSS # --------------------------------------------------------------------------- CUSTOM_CSS = """ """ # --------------------------------------------------------------------------- # RAG System # --------------------------------------------------------------------------- class RAGSystem: """Core RAG pipeline: document processing, retrieval, and generation.""" def __init__(self) -> None: """Initialize the RAG system (models loaded separately via cache).""" self.chunk_size = 500 self.chunk_overlap = 50 # ------------------------------------------------------------------ # Model loading (cached by Streamlit) # ------------------------------------------------------------------ @staticmethod @st.cache_resource def load_models() -> tuple: """Load flan-t5-small and sentence-transformers. Returns: Tuple of (text-generation pipeline, SentenceTransformer model). """ from transformers import pipeline as hf_pipeline from sentence_transformers import SentenceTransformer with st.spinner("Loading language model (flan-t5-small)..."): llm = hf_pipeline( "text-generation", model="google/flan-t5-small", max_new_tokens=200, do_sample=False, ) with st.spinner("Loading embedding model (all-MiniLM-L6-v2)..."): embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") return llm, embedder # ------------------------------------------------------------------ # Vector store # ------------------------------------------------------------------ @staticmethod def setup_vector_store(): """Create an in-memory ChromaDB client and collection. Returns: Tuple of (chromadb.Client, Collection). """ import chromadb client = chromadb.Client() collection = client.get_or_create_collection( name="rag_documents", metadata={"hnsw:space": "cosine"}, ) return client, collection # ------------------------------------------------------------------ # Text extraction # ------------------------------------------------------------------ def extract_text(self, uploaded_file) -> dict: """Extract text content from an uploaded file. Args: uploaded_file: Streamlit UploadedFile object. Returns: Dict with keys 'filename', 'text', and 'type'. Raises: ValueError: If the file type is unsupported or extraction fails. """ filename = uploaded_file.name file_ext = os.path.splitext(filename)[1].lower() if file_ext == ".txt": text = self._extract_txt(uploaded_file) elif file_ext == ".pdf": text = self._extract_pdf(uploaded_file) elif file_ext == ".docx": text = self._extract_docx(uploaded_file) elif file_ext == ".csv": text = self._extract_csv(uploaded_file) else: raise ValueError(f"Unsupported file type: {file_ext}") if not text or not text.strip(): raise ValueError(f"No text content could be extracted from {filename}") return { "filename": filename, "text": text.strip(), "type": file_ext.lstrip("."), } @staticmethod def _extract_txt(uploaded_file) -> str: raw = uploaded_file.read() for encoding in ("utf-8", "latin-1", "cp1252"): try: return raw.decode(encoding) except (UnicodeDecodeError, AttributeError): continue return raw.decode("utf-8", errors="replace") @staticmethod def _extract_pdf(uploaded_file) -> str: try: from PyPDF2 import PdfReader except ImportError as exc: raise ValueError("PyPDF2 is required for PDF processing") from exc reader = PdfReader(io.BytesIO(uploaded_file.read())) pages = [] for page in reader.pages: page_text = page.extract_text() if page_text: pages.append(page_text) return "\n\n".join(pages) @staticmethod def _extract_docx(uploaded_file) -> str: try: from docx import Document except ImportError as exc: raise ValueError("python-docx is required for DOCX processing") from exc doc = Document(io.BytesIO(uploaded_file.read())) paragraphs = [p.text for p in doc.paragraphs if p.text.strip()] return "\n\n".join(paragraphs) @staticmethod def _extract_csv(uploaded_file) -> str: df = pd.read_csv(uploaded_file) rows = [] for _, row in df.iterrows(): parts = [f"{col}: {val}" for col, val in row.items() if pd.notna(val)] rows.append(". ".join(parts)) return "\n\n".join(rows) # ------------------------------------------------------------------ # Chunking # ------------------------------------------------------------------ def chunk_text( self, text: str, chunk_size: int = 500, overlap: int = 50 ) -> list[str]: """Split text into overlapping chunks, breaking at sentence boundaries. Args: text: The full document text. chunk_size: Maximum characters per chunk. overlap: Number of overlapping characters between chunks. Returns: List of text chunks. """ if not text or not text.strip(): return [] sentences = self._split_sentences(text) chunks: list[str] = [] current_chunk: list[str] = [] current_length = 0 for sentence in sentences: sentence = sentence.strip() if not sentence: continue sentence_len = len(sentence) if current_length + sentence_len > chunk_size and current_chunk: chunks.append(" ".join(current_chunk)) # Keep tail sentences for overlap overlap_chunk: list[str] = [] overlap_len = 0 for s in reversed(current_chunk): if overlap_len + len(s) > overlap: break overlap_chunk.insert(0, s) overlap_len += len(s) current_chunk = overlap_chunk current_length = overlap_len current_chunk.append(sentence) current_length += sentence_len if current_chunk: chunks.append(" ".join(current_chunk)) return chunks @staticmethod def _split_sentences(text: str) -> list[str]: """Naive sentence splitter on '.', '!', '?'.""" import re sentences = re.split(r"(?<=[.!?])\s+", text) return [s for s in sentences if s.strip()] # ------------------------------------------------------------------ # Document ingestion # ------------------------------------------------------------------ def add_document(self, doc_data: dict, embedder, collection) -> int: """Chunk, embed, and add a document to the vector store. Args: doc_data: Dict with 'filename' and 'text'. embedder: SentenceTransformer instance. collection: ChromaDB collection. Returns: Number of chunks added. """ chunks = self.chunk_text(doc_data["text"], self.chunk_size, self.chunk_overlap) if not chunks: return 0 embeddings = embedder.encode(chunks, show_progress_bar=False) ids = [f"{doc_data['filename']}_{uuid.uuid4().hex[:8]}" for _ in chunks] metadatas = [ { "source": doc_data["filename"], "chunk_index": i, "total_chunks": len(chunks), } for i in range(len(chunks)) ] collection.add( ids=ids, embeddings=embeddings.tolist(), documents=chunks, metadatas=metadatas, ) return len(chunks) # ------------------------------------------------------------------ # Search # ------------------------------------------------------------------ def search( self, query: str, embedder, collection, n_results: int = 5, ) -> list[dict]: """Perform semantic search over the vector store. Args: query: User query string. embedder: SentenceTransformer instance. collection: ChromaDB collection. n_results: Maximum number of results to return. Returns: List of dicts with 'text', 'source', 'similarity', and 'chunk_index'. """ if collection.count() == 0: return [] query_embedding = embedder.encode([query], show_progress_bar=False) actual_n = min(n_results, collection.count()) results = collection.query( query_embeddings=query_embedding.tolist(), n_results=actual_n, include=["documents", "metadatas", "distances"], ) formatted: list[dict] = [] if results and results["documents"]: for doc, meta, dist in zip( results["documents"][0], results["metadatas"][0], results["distances"][0], ): # ChromaDB cosine distance is in [0, 2]; convert to similarity similarity = max(0.0, 1.0 - dist) formatted.append( { "text": doc, "source": meta.get("source", "Unknown"), "similarity": round(similarity, 4), "chunk_index": meta.get("chunk_index", 0), } ) formatted.sort(key=lambda x: x["similarity"], reverse=True) return formatted # ------------------------------------------------------------------ # Response generation # ------------------------------------------------------------------ def generate_response( self, query: str, context_docs: list[dict], llm_pipeline, ) -> str: """Generate an answer with flan-t5-small using retrieved context. Args: query: User question. context_docs: Retrieved documents from search(). llm_pipeline: HuggingFace text-generation pipeline. Returns: Generated answer string. """ if not context_docs: return ( "I don't have any documents to reference. " "Please upload documents first, then ask your question." ) # Build context from top results context_parts: list[str] = [] for doc in context_docs[:3]: text = doc["text"][:400] context_parts.append(text) context = "\n\n".join(context_parts) prompt = ( "Answer the following question based on the provided context.\n\n" f"Context:\n{context}\n\n" f"Question: {query}" ) try: result = llm_pipeline(prompt, max_new_tokens=200) answer = result[0]["generated_text"].strip() if not answer or len(answer) < 3: return self._fallback_response(query, context_docs) return answer except Exception as exc: logger.error("Generation error: %s", exc) return self._fallback_response(query, context_docs) def _fallback_response(self, query: str, context_docs: list[dict]) -> str: """Provide relevant excerpts when the LLM response is inadequate. Args: query: User question. context_docs: Retrieved documents. Returns: Formatted fallback response with source excerpts. """ if not context_docs: return "No relevant information found in the uploaded documents." response_parts = ["Here are the most relevant excerpts from your documents:\n"] for i, doc in enumerate(context_docs[:3], 1): excerpt = doc["text"][:300].strip() source = doc["source"] score = doc["similarity"] response_parts.append( f"**Source {i}** ({source}, relevance: {score:.0%}):\n> {excerpt}...\n" ) return "\n".join(response_parts) # --------------------------------------------------------------------------- # Streamlit UI Helpers # --------------------------------------------------------------------------- def init_session_state() -> None: """Initialize all required Streamlit session state variables.""" defaults = { "documents": [], "chat_history": [], "embedder": None, "llm_pipeline": None, "chroma_client": None, "collection": None, "rag_system": None, "models_loaded": False, "retrieved_docs": [], } for key, value in defaults.items(): if key not in st.session_state: st.session_state[key] = value def load_models_into_state() -> None: """Load models and vector store into session state if not already loaded.""" if st.session_state.models_loaded: return rag = RAGSystem() st.session_state.rag_system = rag try: llm, embedder = RAGSystem.load_models() st.session_state.llm_pipeline = llm st.session_state.embedder = embedder client, collection = rag.setup_vector_store() st.session_state.chroma_client = client st.session_state.collection = collection st.session_state.models_loaded = True logger.info("All models and vector store loaded successfully") except Exception as exc: st.error(f"Failed to load models: {exc}") logger.error("Model loading failed: %s", exc) def render_sidebar() -> None: """Render the sidebar with document upload and management.""" with st.sidebar: st.header("Document Management") st.markdown("---") # File uploader uploaded_files = st.file_uploader( "Upload Documents", type=["pdf", "txt", "docx", "csv"], accept_multiple_files=True, help="Supported formats: PDF, TXT, DOCX, CSV", ) if uploaded_files and st.button("Process Documents", type="primary"): process_uploaded_files(uploaded_files) st.markdown("---") # Document list st.subheader("Loaded Documents") if st.session_state.documents: for doc in st.session_state.documents: st.markdown( f'