Spaces:
Running
Running
import hashlib
import html
import os
import time
from pathlib import Path

import streamlit as st
from streamlit import config
# ─── Page Config ──────────────────────────────────────────────────────────────
# Browser-tab title/icon and wide layout with the sidebar open by default.
# This is the first Streamlit call in the script; keep it ahead of any other
# st.* call.
st.set_page_config(
    page_title="DocMind AI – Multimodal RAG",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded",
)
# NOTE(review): these two calls switch off CORS and XSRF protection for the
# Streamlit server. Presumably done so uploads work behind the hosting proxy —
# confirm this is intentional and that the options take effect when set here.
config.set_option("server.enableCORS", False)
config.set_option("server.enableXsrfProtection", False)
# Upper bound on indexed files; compared against rag.get_file_count() to gate
# the uploader and the sample loader.
MAX_FILES = 5
# ─── CSS ──────────────────────────────────────────────────────────────────────
# Global dark-theme stylesheet. The class names defined here (stat-card,
# chat-user, chat-assistant, source-pill, filetype-badge, badge-*, doc-item, …)
# are referenced by the raw HTML snippets rendered throughout the script.
_CUSTOM_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;600;700;800&family=DM+Sans:wght@300;400;500&display=swap');
html, body, [class*="css"] { font-family: 'DM Sans', sans-serif; }
.stApp { background: #0f0f13; color: #e8e8f0; }
[data-testid="stSidebar"] { background: #16161d !important; border-right: 1px solid #2a2a3a; }
.hero-title {
font-family: 'Syne', sans-serif; font-size: 2.8rem; font-weight: 800;
background: linear-gradient(135deg, #7c6af7 0%, #a78bfa 40%, #38bdf8 100%);
-webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text;
line-height: 1.1; margin-bottom: 0.2rem;
}
.hero-sub { color: #6b6b8a; font-size: 1rem; font-weight: 300; letter-spacing: 0.04em; margin-bottom: 2rem; }
.stat-card { background: #1c1c26; border: 1px solid #2a2a3a; border-radius: 12px; padding: 1rem 1.2rem; text-align: center; }
.stat-number { font-family: 'Syne', sans-serif; font-size: 1.6rem; font-weight: 700; color: #a78bfa; }
.stat-label { font-size: 0.75rem; color: #6b6b8a; text-transform: uppercase; letter-spacing: 0.08em; }
.chat-user {
background: #1e1e2e; border: 1px solid #2a2a3a;
border-radius: 12px 12px 4px 12px; padding: 0.9rem 1.1rem; margin: 0.5rem 0; color: #e8e8f0;
}
.chat-assistant {
background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
border: 1px solid #312e81; border-radius: 12px 12px 12px 4px;
padding: 0.9rem 1.1rem; margin: 0.5rem 0; color: #e8e8f0;
}
.chat-label { font-size: 0.7rem; font-weight: 600; text-transform: uppercase; letter-spacing: 0.1em; margin-bottom: 0.4rem; }
.label-user { color: #38bdf8; }
.label-ai { color: #a78bfa; }
.source-pill {
display: inline-block; background: #1f1f2e; border: 1px solid #3730a3;
border-radius: 20px; padding: 0.2rem 0.7rem; font-size: 0.72rem; color: #818cf8; margin: 0.2rem 0.15rem;
}
.memory-badge {
display: inline-block; background: #1a2e1a; border: 1px solid #166534;
border-radius: 20px; padding: 0.2rem 0.7rem; font-size: 0.7rem; color: #4ade80; margin-left: 0.5rem;
}
.filetype-badge {
display: inline-block; padding: 2px 10px; border-radius: 12px;
font-size: 0.72rem; font-weight: 600; text-transform: uppercase; letter-spacing: 0.05em;
}
.ft-pdf { background: #7f1d1d; color: #fca5a5; }
.ft-image { background: #1e1b4b; color: #a5b4fc; }
.ft-csv { background: #064e3b; color: #6ee7b7; }
.ft-excel { background: #064e3b; color: #6ee7b7; }
.ft-docx { background: #1e3a5f; color: #7dd3fc; }
.ft-text { background: #1c1917; color: #d6d3d1; }
.doc-item {
background: #1c1c26; border: 1px solid #2a2a3a; border-radius: 10px;
padding: 0.6rem 0.8rem; margin-bottom: 0.4rem;
}
[data-testid="stFileUploader"] { background: #1c1c26 !important; border: 2px dashed #2a2a3a !important; border-radius: 12px !important; }
.stButton > button {
background: linear-gradient(135deg, #7c3aed, #4f46e5) !important;
color: white !important; border: none !important; border-radius: 8px !important;
font-family: 'DM Sans', sans-serif !important; font-weight: 500 !important;
}
.stButton > button:hover { transform: translateY(-1px) !important; box-shadow: 0 4px 20px rgba(124,58,237,0.4) !important; }
.stTextInput > div > div > input, [data-testid="stChatInputTextArea"] {
background: #1c1c26 !important; border: 1px solid #2a2a3a !important;
color: #e8e8f0 !important; border-radius: 10px !important;
}
.badge-ready { background:#14532d; color:#86efac; padding:3px 10px; border-radius:20px; font-size:0.75rem; }
.badge-empty { background:#1c1917; color:#a8a29e; padding:3px 10px; border-radius:20px; font-size:0.75rem; }
.badge-count { background:#312e81; color:#a5b4fc; padding:3px 10px; border-radius:20px; font-size:0.75rem; }
hr { border-color: #2a2a3a !important; }
::-webkit-scrollbar { width: 6px; }
::-webkit-scrollbar-track { background: #0f0f13; }
::-webkit-scrollbar-thumb { background: #2a2a3a; border-radius: 3px; }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# ─── Cache RAG engine ─────────────────────────────────────────────────────────
@st.cache_resource
def load_rag_engine():
    """Build the RAG engine once and reuse it on every script rerun.

    Streamlit re-executes the whole script on each interaction; without the
    cache decorator a fresh ``RAGEngine`` would be constructed every time.
    ``rag_engine`` is imported lazily so the module is only loaded when the
    engine is first requested.

    NOTE(review): ``st.cache_resource`` shares the returned object across all
    sessions of the app, while ``processed_files`` is tracked per session —
    confirm that a single shared engine is the intended deployment model.

    Returns:
        The cached ``RAGEngine`` instance.
    """
    from rag_engine import RAGEngine

    return RAGEngine()
# ─── Session state ────────────────────────────────────────────────────────────
# Seed the per-session keys on the first run only; on later reruns the values
# already stored in st.session_state are left untouched.
defaults = {
    "messages": [],
    "processed_files": {},  # {filename: md5_hash}
}
for state_key in defaults:
    if state_key in st.session_state:
        continue
    st.session_state[state_key] = defaults[state_key]
def file_type_badge(suffix: str) -> str:
    """Return the HTML ``<span>`` badge for a file suffix.

    The lookup is case-insensitive, so ``".PDF"`` gets the same badge as
    ``".pdf"``. Unknown suffixes fall back to the neutral "text" style with
    the upper-cased suffix as the label; that label is HTML-escaped because it
    is derived from a file name and ends up in raw HTML.

    Args:
        suffix: File extension including the leading dot (e.g. ``".pdf"``).

    Returns:
        An HTML snippet using the ``filetype-badge`` / ``ft-*`` CSS classes.
    """
    m = {
        ".pdf": ("pdf", "PDF"),
        ".txt": ("text", "TXT"),
        ".docx": ("docx", "DOCX"),
        ".doc": ("docx", "DOC"),
        ".csv": ("csv", "CSV"),
        ".xlsx": ("excel", "XLSX"),
        ".xls": ("excel", "XLS"),
        ".jpg": ("image", "IMAGE"),
        ".jpeg": ("image", "IMAGE"),
        ".png": ("image", "IMAGE"),
        ".webp": ("image", "IMAGE"),
    }
    known = m.get(suffix.lower())
    if known is None:
        # Upper-case first, then escape, so entity names are not upper-cased.
        cls, label = "text", html.escape(suffix.upper(), quote=False)
    else:
        cls, label = known
    return f'<span class="filetype-badge ft-{cls}">{label}</span>'
def type_emoji(suffix: str) -> str:
    """Return the emoji used to represent a file suffix.

    The lookup is case-insensitive; unknown suffixes get the generic
    document emoji.

    Args:
        suffix: File extension including the leading dot (e.g. ``".csv"``).

    Returns:
        A single emoji string.
    """
    m = {
        ".pdf": "📄", ".txt": "📄",
        ".docx": "📝", ".doc": "📝",
        ".csv": "📊", ".xlsx": "📊", ".xls": "📊",
        ".jpg": "🖼️", ".jpeg": "🖼️", ".png": "🖼️", ".webp": "🖼️",
    }
    return m.get(suffix.lower(), "📄")
# ─── Load RAG engine & get document state ─────────────────────────────────────
# Snapshot the engine state once per script run; the sidebar, the stats cards
# and the chat sections below all read these module-level values.
rag = load_rag_engine()
documents = rag.get_documents()  # [{name, type, chunk_count}]
doc_loaded = len(documents) > 0  # True once at least one document is listed
total_chunks = rag.get_total_chunks()
file_count = rag.get_file_count()  # compared against MAX_FILES to gate uploads
# ─── Sidebar ──────────────────────────────────────────────────────────────────
# Sidebar layout: status + document list (with per-file removal), uploader,
# sample loader, chat/reset actions, and a static "stack" footer.
#
# The uploader widget key carries a generation counter kept in session state.
# Bumping the counter gives the widget a new key, which makes Streamlit create
# a fresh, empty uploader. Without this, the file still held by the uploader
# would be re-ingested on the rerun that follows "remove" or "Reset All",
# because its hash has just been dropped from processed_files.
with st.sidebar:
    st.markdown('<p style="font-family:Syne,sans-serif;font-size:1.3rem;font-weight:700;color:#a78bfa;">🧠 DocMind AI</p>', unsafe_allow_html=True)
    st.markdown('<p style="color:#6b6b8a;font-size:0.8rem;">Multimodal RAG · Multi-File · Memory</p>', unsafe_allow_html=True)
    st.markdown("---")

    uploader_generation = st.session_state.get("uploader_key", 0)

    # ── Document List ─────────────────────────────────────────────────────────
    if documents:
        mem_count = rag.get_memory_count()
        st.markdown(
            f'<span class="badge-ready">✓ Ready</span> '
            f'<span class="badge-count">{file_count}/{MAX_FILES} files</span>',
            unsafe_allow_html=True,
        )
        st.markdown(
            f'<p style="color:#6b6b8a;font-size:0.78rem;margin-top:0.3rem;">'
            f'{total_chunks} total chunks · {mem_count} exchanges in memory</p>',
            unsafe_allow_html=True,
        )
        st.markdown("")
        # Show each document with a remove button
        for doc in documents:
            # File names come from uploads; escape before embedding in raw HTML.
            safe_name = html.escape(str(doc["name"]), quote=False)
            col_doc, col_rm = st.columns([5, 1])
            with col_doc:
                badge = file_type_badge(doc["type"])
                emoji = type_emoji(doc["type"])
                st.markdown(
                    f'<div class="doc-item">'
                    f'{badge} <b style="color:#e8e8f0;font-size:0.82rem;">{safe_name}</b>'
                    f'<br><span style="color:#6b6b8a;font-size:0.72rem;">'
                    f'{emoji} {doc["chunk_count"]} chunks</span>'
                    f'</div>',
                    unsafe_allow_html=True,
                )
            with col_rm:
                st.markdown('<div style="padding-top:0.6rem;"></div>', unsafe_allow_html=True)
                if st.button("❌", key=f"rm_{doc['name']}", help=f"Remove {doc['name']}"):
                    rag.remove_file(doc["name"])
                    # Remove from processed_files tracking
                    st.session_state.processed_files = {
                        k: v for k, v in st.session_state.processed_files.items()
                        if k != doc["name"]
                    }
                    # Reset the uploader so a still-selected copy of this file
                    # is not immediately ingested again.
                    st.session_state["uploader_key"] = uploader_generation + 1
                    st.rerun()
    else:
        st.markdown('<span class="badge-empty">○ No documents loaded</span>', unsafe_allow_html=True)
    st.markdown("---")

    # ── Upload Area ───────────────────────────────────────────────────────────
    st.markdown(
        '<p style="color:#6b6b8a;font-size:0.78rem;font-weight:600;text-transform:uppercase;letter-spacing:0.08em;">'
        'Upload Document</p>',
        unsafe_allow_html=True,
    )
    st.markdown(
        '<p style="color:#6b6b8a;font-size:0.72rem;">'
        'PDF · TXT · DOCX · CSV · XLSX · JPG · PNG</p>',
        unsafe_allow_html=True,
    )
    if file_count >= MAX_FILES:
        st.warning(f"Maximum {MAX_FILES} files reached. Remove a file to upload more.")
        uploaded_file = None
    else:
        uploaded_file = st.file_uploader(
            "Upload",
            type=["pdf", "txt", "docx", "doc", "csv", "xlsx", "xls",
                  "jpg", "jpeg", "png", "webp"],
            label_visibility="collapsed",
            key=f"uploader_{uploader_generation}",
        )
    if uploaded_file:
        # MD5 is used only as a content fingerprint for de-duplication.
        file_hash = hashlib.md5(uploaded_file.read()).hexdigest()
        uploaded_file.seek(0)  # rewind so the engine reads from the start
        # Check if this exact file (by hash) was already processed
        already_processed = file_hash in st.session_state.processed_files.values()
        if not already_processed:
            suffix = Path(uploaded_file.name).suffix.lower()
            type_msg = {
                ".pdf": "Reading PDF...",
                ".txt": "Reading text...",
                ".docx": "Reading Word doc...",
                ".doc": "Reading Word doc...",
                ".csv": "Parsing CSV...",
                ".xlsx": "Parsing Excel...",
                ".xls": "Parsing Excel...",
                ".jpg": "🖼️ Processing image (OCR + captioning)...",
                ".jpeg": "🖼️ Processing image (OCR + captioning)...",
                ".png": "🖼️ Processing image (OCR + captioning)...",
                ".webp": "🖼️ Processing image (OCR + captioning)...",
            }.get(suffix, "Processing...")
            with st.spinner(type_msg):
                try:
                    chunks = rag.ingest_file(uploaded_file)
                except ValueError as e:
                    st.error(str(e))
                except Exception as e:
                    st.error(f"Failed to process file: {e}")
                else:
                    st.session_state.processed_files[uploaded_file.name] = file_hash
                    st.success(f"✓ Indexed {chunks} chunks from {uploaded_file.name}!")
                    st.rerun()
    st.markdown("---")

    # ── Sample doc ────────────────────────────────────────────────────────────
    st.markdown(
        '<p style="color:#6b6b8a;font-size:0.78rem;font-weight:600;text-transform:uppercase;letter-spacing:0.08em;">'
        'Or try a sample</p>',
        unsafe_allow_html=True,
    )
    if st.button("📥 Load Sample: AI Report", use_container_width=True):
        if file_count >= MAX_FILES:
            st.error(f"Maximum {MAX_FILES} files reached. Remove a file first.")
        else:
            with st.spinner("Downloading sample..."):
                from data_downloader import download_sample_doc
                try:
                    # The download is inside the try so a network failure is
                    # reported in the sidebar instead of aborting the script.
                    path, name = download_sample_doc()
                    chunks = rag.ingest_path(path, name)
                except ValueError as e:
                    st.error(str(e))
                except Exception as e:
                    st.error(f"Failed to load sample: {e}")
                else:
                    st.session_state.processed_files[name] = "sample"
                    st.success(f"✓ {chunks} chunks loaded!")
                    st.rerun()
    st.markdown("---")

    # ── Action buttons ────────────────────────────────────────────────────────
    col_a, col_b = st.columns(2)
    with col_a:
        if st.button("🗑️ Clear Chat", use_container_width=True):
            st.session_state.messages = []
            rag.clear_memory()
            st.rerun()
    with col_b:
        if st.button("🔄 Reset All", use_container_width=True):
            rag.reset()
            st.session_state.messages = []
            st.session_state.processed_files = {}
            # Also clear the uploader, otherwise its file is re-indexed at once.
            st.session_state["uploader_key"] = uploader_generation + 1
            st.rerun()
    st.markdown("---")
    st.markdown(
        '<p style="color:#6b6b8a;font-size:0.72rem;line-height:1.8;">'
        '<b style="color:#a78bfa;">Stack</b><br>'
        '🔗 LangChain · ChromaDB<br>'
        '🤗 MiniLM Embeddings<br>'
        '🦙 Llama-3 / Mistral-7B<br>'
        '🖼️ BLIP + VLM Captioning<br>'
        '💬 Conversation Memory<br>'
        f'📁 Up to {MAX_FILES} files simultaneously<br>'
        '🌊 Streamlit + FastAPI'
        '</p>',
        unsafe_allow_html=True,
    )
# ─── Main Area ────────────────────────────────────────────────────────────────
# Page title and tagline. The file limit in the tagline is taken from
# MAX_FILES so the text cannot drift from the limit that is enforced.
st.markdown('<h1 class="hero-title">DocMind AI</h1>', unsafe_allow_html=True)
st.markdown(
    '<p class="hero-sub">'
    f'PDF · Word · CSV · Excel · Images — Upload up to {MAX_FILES} files. Ask anything. Remembers your conversation.'
    '</p>',
    unsafe_allow_html=True,
)
# ── Stats ─────────────────────────────────────────────────────────────────────
def _render_stat_card(column, value, label):
    """Render one stat card (big value over a small caption) into *column*."""
    card_html = (
        f'<div class="stat-card">'
        f'<div class="stat-number">{value}</div>'
        f'<div class="stat-label">{label}</div></div>'
    )
    with column:
        st.markdown(card_html, unsafe_allow_html=True)


# Four cards in a row: chunks, files, questions, memory.
chunks_col, files_col, questions_col, memory_col = st.columns(4)
# A falsy chunk total is shown as a dash instead of "0".
_render_stat_card(chunks_col, total_chunks or "—", "Chunks Indexed")
_render_stat_card(files_col, f"{file_count}/{MAX_FILES}", "Files Loaded")
# Messages are stored as user/assistant pairs, hence the halving.
_render_stat_card(questions_col, len(st.session_state.messages) // 2, "Questions Asked")
_render_stat_card(memory_col, rag.get_memory_count(), "Memory Window")
st.markdown("<br>", unsafe_allow_html=True)
# ─── Chat history ─────────────────────────────────────────────────────────────
# Renders either an empty-state panel (no messages yet) or the conversation.
# Everything that originates outside this script — file names, user prompts,
# model answers, source labels — is HTML-escaped before being placed into
# markup rendered with unsafe_allow_html=True, so stray "<" or "&" characters
# (or injected tags) cannot break or alter the page.
# NOTE(review): escaping the answer means any HTML the model emits is shown as
# text rather than rendered — confirm that is the desired presentation.
if not st.session_state.messages:
    if doc_loaded:
        # Show loaded files summary
        file_names = ", ".join(
            f"<b style='color:#e8e8f0;'>{html.escape(str(d['name']), quote=False)}</b>"
            for d in documents
        )
        # dict.fromkeys de-duplicates while keeping first-seen order; a set
        # would leave the emoji order arbitrary.
        emojis = " ".join(dict.fromkeys(type_emoji(d["type"]) for d in documents))
        plural = "s" if file_count > 1 else ""
        upload_hint = (
            f"You can also upload more files (up to {MAX_FILES})."
            if file_count < MAX_FILES
            else ""
        )
        st.markdown(
            f'<div style="text-align:center;padding:3rem;color:#6b6b8a;">'
            f'<div style="font-size:2.5rem;margin-bottom:1rem;">{emojis}</div>'
            f'<p style="font-size:1rem;color:#a78bfa;">'
            f'{file_count} document{plural} ready!'
            f'</p>'
            f'<p style="font-size:0.85rem;">Ask anything about {file_names}</p>'
            f'<p style="font-size:0.78rem;margin-top:0.5rem;">'
            f"I'll remember your conversation — ask follow-up questions naturally. "
            f'{upload_hint}'
            f'</p>'
            f'</div>',
            unsafe_allow_html=True,
        )
    else:
        st.markdown(
            '<div style="text-align:center;padding:4rem 2rem;color:#6b6b8a;">'
            '<div style="font-size:3rem;margin-bottom:1rem;">🧠</div>'
            '<p style="font-size:1.1rem;color:#a78bfa;font-family:\'Syne\',sans-serif;font-weight:600;">'
            f'Multimodal RAG — Upload up to {MAX_FILES} files'
            '</p>'
            '<p style="font-size:0.85rem;margin-top:0.5rem;">'
            '📄 PDF · 📝 Word · 📊 CSV/Excel · 🖼️ Images<br><br>'
            'Upload in the sidebar or load the sample AI report to get started.<br>'
            'You can upload multiple files and ask questions across all of them.'
            '</p>'
            '</div>',
            unsafe_allow_html=True,
        )
else:
    for msg in st.session_state.messages:
        safe_content = html.escape(str(msg["content"]), quote=False)
        if msg["role"] == "user":
            st.markdown(
                f'<div class="chat-user">'
                f'<div class="chat-label label-user">You</div>'
                f'{safe_content}'
                f'</div>',
                unsafe_allow_html=True,
            )
        else:
            mem = msg.get("memory_count", 0)
            mem_badge = f'<span class="memory-badge">💬 {mem} in memory</span>' if mem > 0 else ""
            sources_html = ""
            if msg.get("sources"):
                pills = "".join(
                    f'<span class="source-pill">📎 {html.escape(str(s), quote=False)}</span>'
                    for s in msg["sources"]
                )
                sources_html = f'<div style="margin-top:0.7rem;">{pills}</div>'
            st.markdown(
                f'<div class="chat-assistant">'
                f'<div class="chat-label label-ai">DocMind AI {mem_badge}</div>'
                f'{safe_content}'
                f'{sources_html}'
                f'</div>',
                unsafe_allow_html=True,
            )
# ─── Chat Input ───────────────────────────────────────────────────────────────
# The input is disabled until at least one document is loaded. With exactly one
# file the placeholder is tailored to its type; otherwise it mentions the count.
st.markdown("<br>", unsafe_allow_html=True)
if not doc_loaded:
    st.chat_input("Upload a document first...", disabled=True)
else:
    if file_count == 1:
        doc_type = documents[0]["type"]
        placeholder = {
            ".pdf": "Ask anything about this PDF...",
            ".txt": "Ask anything about this text...",
            ".docx": "Ask anything about this document...",
            ".doc": "Ask anything about this document...",
            ".csv": "Ask about the data, columns, or statistics...",
            ".xlsx": "Ask about the spreadsheet data...",
            ".xls": "Ask about the spreadsheet data...",
            ".jpg": "Ask me what I see in this image...",
            ".jpeg": "Ask me what I see in this image...",
            ".png": "Ask me what I see in this image...",
            ".webp": "Ask me what I see in this image...",
        }.get(doc_type, "Ask anything about your document...")
    else:
        placeholder = f"Ask anything about your {file_count} documents..."
    if prompt := st.chat_input(placeholder):
        st.session_state.messages.append({"role": "user", "content": prompt})
        try:
            with st.spinner("🔍 Retrieving & generating..."):
                answer, sources = rag.query(prompt)
        except Exception as exc:
            # Drop the unanswered prompt so the history stays in user/assistant
            # pairs (the "Questions Asked" stat relies on that), and report the
            # failure in the page instead of aborting the script run.
            st.session_state.messages.pop()
            st.error(f"Failed to generate an answer: {exc}")
        else:
            mem_count = rag.get_memory_count()
            st.session_state.messages.append({
                "role": "assistant",
                "content": answer,
                "sources": sources,
                "memory_count": mem_count,
            })
            # Rerun so the chat-history section above renders the new exchange.
            st.rerun()