import streamlit as st
import os
import re
import pdfplumber
from io import BytesIO
from docx import Document
import pandas as pd
import numpy as np
import faiss
from huggingface_hub import InferenceClient
# ============================================
# SMART RAG API - HuggingFace Space Version
# Technologies: Streamlit, FAISS, HuggingFace Hub
# Parsers: pdfplumber, python-docx, pandas
# ============================================

# ============== CONFIG ==============
CHUNK_SIZE = 500      # max characters per chunk
CHUNK_OVERLAP = 50    # characters shared between consecutive chunks
EMBEDDING_DIM = 384   # dimensionality of the hash-based embedding vectors

# ============== TEXT CHUNKING ==============
def chunk_text(text):
    """Convert text into clean, meaningful chunks with overlap."""
    if not text or not text.strip():
        return []
    text = " ".join(text.strip().split())
    chunks = []
    start = 0
    chunk_index = 0
    while start < len(text):
        end = start + CHUNK_SIZE
        chunk_content = text[start:end]
        # Try to break at a sentence boundary
        if end < len(text):
            last_period = chunk_content.rfind(". ")
            if last_period > CHUNK_SIZE * 0.5:
                chunk_content = chunk_content[:last_period + 1]
                end = start + last_period + 1
        chunks.append({
            "content": chunk_content.strip(),
            "chunk_index": chunk_index
        })
        chunk_index += 1
        start = end - CHUNK_OVERLAP
        if start >= len(text) - CHUNK_OVERLAP:
            break
    return chunks
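
# Minimal sanity check, added here for illustration (cheap enough to run on
# every Streamlit rerun): no chunk exceeds CHUNK_SIZE, and a long text yields
# several overlapping chunks.
_demo_chunks = chunk_text("word " * 300)  # ~1,500 chars -> multiple chunks
assert all(len(c["content"]) <= CHUNK_SIZE for c in _demo_chunks)
assert len(_demo_chunks) > 1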

# ============== DOCUMENT PARSERS ==============
def parse_pdf(file_bytes):
    """Extract text from a .pdf via pdfplumber, tagging each page."""
    text_parts = []
    with pdfplumber.open(BytesIO(file_bytes)) as pdf:
        for i, page in enumerate(pdf.pages):
            page_text = page.extract_text() or ""
            if page_text.strip():
                text_parts.append(f"[Page {i + 1}]\n{page_text}")
    return "\n\n".join(text_parts)

def parse_docx(file_bytes):
    """Extract text from a .docx via python-docx."""
    doc = Document(BytesIO(file_bytes))
    paragraphs = [para.text for para in doc.paragraphs if para.text.strip()]
    return "\n\n".join(paragraphs)

def parse_txt(file_bytes):
    """Decode a .txt file as UTF-8, replacing any undecodable bytes."""
    return file_bytes.decode("utf-8", errors="replace")

def parse_csv(file_bytes):
    """Serialize a .csv into text using pandas."""
    df = pd.read_csv(BytesIO(file_bytes))
    lines = [
        f"Columns: {', '.join(df.columns.tolist())}",
        f"Total rows: {len(df)}",
        "\nData:"
    ]
    # Serialize at most the first 50 rows to keep chunks manageable
    for idx, row in df.head(50).iterrows():
        row_text = " | ".join([f"{col}: {val}" for col, val in row.items()])
        lines.append(row_text)
    return "\n".join(lines)
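
# Example of the serialized form (illustrative values): for a CSV with
# columns "name,age" and the single row "Ada,36", parse_csv() returns:
#   Columns: name, age
#   Total rows: 1
#
#   Data:
#   name: Ada | age: 36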

def parse_document(file_bytes, filename):
    """Parse document and return chunks with metadata."""
    ext = filename.split(".")[-1].lower()
    if ext == "pdf":
        text = parse_pdf(file_bytes)
    elif ext == "docx":
        text = parse_docx(file_bytes)
    elif ext == "txt":
        text = parse_txt(file_bytes)
    elif ext == "csv":
        text = parse_csv(file_bytes)
    else:
        text = f"[Unsupported file type: {ext}]"
    chunks = chunk_text(text)
    # Add metadata (filename, file type) to every chunk
    for chunk in chunks:
        chunk["source"] = filename
        chunk["file_type"] = ext
    return {"text": text, "chunks": chunks}
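
# Illustrative self-check of the return shape, assuming a tiny .txt upload:
_demo = parse_document(b"Alpha. Beta.", "notes.txt")
assert _demo["chunks"][0]["source"] == "notes.txt"
assert _demo["chunks"][0]["file_type"] == "txt"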

# ============== EMBEDDINGS (custom, hash-based) ==============
def simple_tokenize(text):
    """Lowercase and split into ASCII word tokens (digits and accented
    characters are dropped - acceptable for this lightweight demo)."""
    text = text.lower()
    tokens = re.findall(r'\b[a-z]+\b', text)
    return tokens

def hash_embed(text, dim=EMBEDDING_DIM):
    """Simple hash-based embedding (lightweight alternative to sentence-transformers).

    Note: Python's str hash is salted per process (PYTHONHASHSEED), so vectors
    are only comparable within a single running process. That is fine here:
    the FAISS index lives in session state and is rebuilt on restart.
    """
    tokens = simple_tokenize(text)
    vector = np.zeros(dim)
    for token in tokens:
        idx = hash(token) % dim
        vector[idx] += 1
    # Normalize to unit length so L2 distance behaves like cosine distance
    norm = np.linalg.norm(vector)
    if norm > 0:
        vector = vector / norm
    return vector

def embed_texts(texts):
    """Generate embeddings for multiple texts."""
    return np.array([hash_embed(t) for t in texts]).astype("float32")
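
# Sanity check (illustrative): embeddings come back unit-normalized, so the
# L2 distance FAISS computes below ranks neighbors the same way cosine
# similarity would (for unit vectors, ||a - b||^2 = 2 - 2*cos(a, b)).
_v = embed_texts(["hello world"])
assert _v.shape == (1, EMBEDDING_DIM)
assert abs(np.linalg.norm(_v[0]) - 1.0) < 1e-5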

# ============== VECTOR STORE (FAISS) ==============
class VectorStore:
    """Store embeddings in FAISS for similarity search."""

    def __init__(self):
        self.index = None
        self.documents = []

    def add_documents(self, chunks):
        """Add document chunks to the FAISS index."""
        if not chunks:
            return 0
        texts = [c["content"] for c in chunks]
        embeddings = embed_texts(texts)
        if self.index is None:
            # Exact (brute-force) L2 index; fine at this scale
            self.index = faiss.IndexFlatL2(EMBEDDING_DIM)
        self.index.add(embeddings)
        self.documents.extend(chunks)
        return len(chunks)

    def search(self, query, top_k=5):
        """Perform similarity search."""
        if self.index is None or self.index.ntotal == 0:
            return []
        query_embedding = embed_texts([query])
        distances, indices = self.index.search(query_embedding, top_k)
        results = []
        for i, idx in enumerate(indices[0]):
            # FAISS pads with -1 when fewer than top_k vectors are indexed
            if 0 <= idx < len(self.documents):
                doc = self.documents[idx].copy()
                doc["score"] = float(distances[0][i])  # L2 distance: lower = closer
                results.append(doc)
        return results

    def clear(self):
        """Clear all documents."""
        self.index = None
        self.documents = []

    def get_stats(self):
        """Get store statistics."""
        return {
            "total_documents": len(self.documents),
            "index_size": self.index.ntotal if self.index else 0
        }
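
# Minimal usage sketch, added for illustration (the chunk dict mirrors what
# parse_document() produces):
_demo_store = VectorStore()
_demo_store.add_documents([{"content": "FAISS performs fast similarity search.",
                            "chunk_index": 0, "source": "demo.txt", "file_type": "txt"}])
assert _demo_store.search("fast similarity search", top_k=1)[0]["source"] == "demo.txt"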

# ============== LLM SERVICE (HuggingFace Hub) ==============
def get_llm_client():
    """Get a HuggingFace Inference Client, reading the API key from the
    environment or from Streamlit secrets."""
    token = os.getenv("HUGGINGFACE_API_KEY", "")
    if not token:
        try:
            token = st.secrets["HUGGINGFACE_API_KEY"]
        except Exception:  # no secrets file, or the key is not set
            token = ""
    return InferenceClient(model="HuggingFaceH4/zephyr-7b-beta", token=token if token else None)

def generate_answer(question, context):
    """Send prompt to LLM and return answer."""
    prompt = f"""You are a helpful assistant that answers questions based on the provided context.

CONTEXT:
{context}

INSTRUCTIONS:
- Answer the question based ONLY on the context provided above.
- If the context doesn't contain enough information, say so.
- Be concise and direct.
- Mention which source the information comes from if relevant.

QUESTION: {question}

ANSWER:"""
    try:
        client = get_llm_client()
        response = client.chat_completion(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=512,
            temperature=0.7
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error: {str(e)}"
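
# Example call (not executed here - it needs a valid HUGGINGFACE_API_KEY and
# makes a network request; the argument values are illustrative):
#   generate_answer("What is FAISS?",
#                   "[Source 1: notes.txt]\nFAISS is a vector search library.")
# returns the model's grounded answer string, or an "Error: ..." string on failure.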

# ============== STREAMLIT UI ==============
st.set_page_config(
    page_title="Smart RAG API",
    page_icon="📚",
    layout="wide"
)
st.title("📚 Smart RAG API")
st.markdown("""
**Retrieval-Augmented Generation API** - Upload documents and ask questions!

**Technologies:** Streamlit • FAISS • pdfplumber • python-docx • pandas • HuggingFace Hub
""")

# Initialize vector store once per session
if "vector_store" not in st.session_state:
    st.session_state.vector_store = VectorStore()

# Sidebar
with st.sidebar:
    st.header("📊 Status")
    stats = st.session_state.vector_store.get_stats()
    st.success("✅ Running")
    st.metric("Documents in Store", stats["total_documents"])
    st.metric("Index Size", stats["index_size"])
    st.divider()
    if st.button("🗑️ Clear All Documents"):
        st.session_state.vector_store.clear()
        st.success("Cleared!")
        st.rerun()
    st.divider()
    st.markdown("### 📁 Supported Files")
    st.markdown("""
    - 📄 **PDF** (pdfplumber)
    - 📝 **DOCX** (python-docx)
    - 📃 **TXT** (direct)
    - 📊 **CSV** (pandas)
    """)
    st.divider()
    st.markdown("### 🛠️ Tech Stack")
    st.markdown("""
    - **Vector Store:** FAISS
    - **LLM:** HuggingFace Hub
    - **Embeddings:** Custom (lightweight, hash-based)
    - **UI:** Streamlit
    """)

# Main layout
col1, col2 = st.columns(2)

# Upload Section
with col1:
    st.header("📤 Upload Document")
    uploaded_file = st.file_uploader(
        "Choose a file",
        type=["pdf", "docx", "txt", "csv"],
        help="Supported: PDF, DOCX, TXT, CSV"
    )
    if uploaded_file:
        file_icon = {"pdf": "📄", "docx": "📝", "txt": "📃", "csv": "📊"}
        ext = uploaded_file.name.split(".")[-1].lower()
        st.info(f"{file_icon.get(ext, '📄')} **{uploaded_file.name}** ({uploaded_file.size} bytes)")
        if st.button("📤 Process Document", type="primary"):
            with st.spinner("Processing document..."):
                try:
                    file_bytes = uploaded_file.getvalue()
                    parsed = parse_document(file_bytes, uploaded_file.name)
                    added = st.session_state.vector_store.add_documents(parsed["chunks"])
                    st.success(f"✅ Success! Added **{added} chunks** to the knowledge base.")
                    st.json({
                        "filename": uploaded_file.name,
                        "file_type": ext,
                        "chunks_created": added
                    })
                except Exception as e:
                    st.error(f"❌ Error: {str(e)}")

# Query Section
with col2:
    st.header("💬 Ask Questions")
    question = st.text_area(
        "Your question:",
        placeholder="What is this document about?",
        height=100
    )
    top_k = st.slider("Number of sources to retrieve", 1, 10, 3)
    if st.button("🔍 Search & Answer", type="primary"):
        if not question:
            st.warning("⚠️ Please enter a question")
        elif st.session_state.vector_store.get_stats()["total_documents"] == 0:
            st.warning("⚠️ Please upload documents first")
        else:
            with st.spinner("Searching and generating answer..."):
                # Vector search
                results = st.session_state.vector_store.search(question, top_k)
                if results:
                    # Build the LLM context from the retrieved chunks
                    context_parts = []
                    for i, r in enumerate(results, 1):
                        context_parts.append(f"[Source {i}: {r['source']}]\n{r['content']}")
                    context = "\n\n".join(context_parts)
                    # Generate answer via LLM
                    answer = generate_answer(question, context)
                    # Display answer
                    st.subheader("📝 Answer")
                    st.markdown(answer)
                    # Display sources
                    st.subheader("📚 Sources")
                    for i, r in enumerate(results, 1):
                        with st.expander(f"Source {i}: {r['source']} (score: {r['score']:.3f})"):
                            st.write(r["content"][:500] + "..." if len(r["content"]) > 500 else r["content"])
                else:
                    st.warning("No relevant documents found.")

# Footer
st.divider()
st.caption("📚 **Smart RAG API** | Built with FAISS, HuggingFace Hub, pdfplumber, python-docx, pandas | By Emon Karmoker")