# docmind-ai / app.py
# Last commit cd4a662 by Ryanfafa: "Update app.py" (verified)
import hashlib
import html
import os
import time
from pathlib import Path

import streamlit as st
from streamlit import config
# ─── Page Config ──────────────────────────────────────────────────────────────
# Browser tab title/icon and wide layout with the sidebar open by default.
# NOTE(review): Streamlit has historically required set_page_config to be the
# first st.* call in the script — keep it above any other Streamlit output.
st.set_page_config(
page_title="DocMind AI – Multimodal RAG",
page_icon="🧠",
layout="wide",
initial_sidebar_state="expanded",
)
# NOTE(review): these are *server* options being set from inside the script,
# i.e. after the server has already started — confirm they actually take
# effect here; the usual place is .streamlit/config.toml or CLI flags.
# Disabling XSRF protection also weakens protection of the upload endpoint;
# verify it is really required for this deployment.
config.set_option("server.enableCORS", False)
config.set_option("server.enableXsrfProtection", False)
# Upper bound on simultaneously indexed documents: uploads and the sample
# loader are refused once rag.get_file_count() reaches this value.
MAX_FILES = 5
# ─── CSS ──────────────────────────────────────────────────────────────────────
# Global theme injected once per run: dark palette, fonts, chat bubbles,
# stat cards, badges and per-file-type colours. The class names defined here
# (.stat-card, .chat-user, .chat-assistant, .doc-item, .ft-*, .badge-*, ...)
# are referenced by the raw-HTML snippets rendered further down the script.
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;600;700;800&family=DM+Sans:wght@300;400;500&display=swap');
html, body, [class*="css"] { font-family: 'DM Sans', sans-serif; }
.stApp { background: #0f0f13; color: #e8e8f0; }
[data-testid="stSidebar"] { background: #16161d !important; border-right: 1px solid #2a2a3a; }
.hero-title {
font-family: 'Syne', sans-serif; font-size: 2.8rem; font-weight: 800;
background: linear-gradient(135deg, #7c6af7 0%, #a78bfa 40%, #38bdf8 100%);
-webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text;
line-height: 1.1; margin-bottom: 0.2rem;
}
.hero-sub { color: #6b6b8a; font-size: 1rem; font-weight: 300; letter-spacing: 0.04em; margin-bottom: 2rem; }
.stat-card { background: #1c1c26; border: 1px solid #2a2a3a; border-radius: 12px; padding: 1rem 1.2rem; text-align: center; }
.stat-number { font-family: 'Syne', sans-serif; font-size: 1.6rem; font-weight: 700; color: #a78bfa; }
.stat-label { font-size: 0.75rem; color: #6b6b8a; text-transform: uppercase; letter-spacing: 0.08em; }
.chat-user {
background: #1e1e2e; border: 1px solid #2a2a3a;
border-radius: 12px 12px 4px 12px; padding: 0.9rem 1.1rem; margin: 0.5rem 0; color: #e8e8f0;
}
.chat-assistant {
background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
border: 1px solid #312e81; border-radius: 12px 12px 12px 4px;
padding: 0.9rem 1.1rem; margin: 0.5rem 0; color: #e8e8f0;
}
.chat-label { font-size: 0.7rem; font-weight: 600; text-transform: uppercase; letter-spacing: 0.1em; margin-bottom: 0.4rem; }
.label-user { color: #38bdf8; }
.label-ai { color: #a78bfa; }
.source-pill {
display: inline-block; background: #1f1f2e; border: 1px solid #3730a3;
border-radius: 20px; padding: 0.2rem 0.7rem; font-size: 0.72rem; color: #818cf8; margin: 0.2rem 0.15rem;
}
.memory-badge {
display: inline-block; background: #1a2e1a; border: 1px solid #166534;
border-radius: 20px; padding: 0.2rem 0.7rem; font-size: 0.7rem; color: #4ade80; margin-left: 0.5rem;
}
.filetype-badge {
display: inline-block; padding: 2px 10px; border-radius: 12px;
font-size: 0.72rem; font-weight: 600; text-transform: uppercase; letter-spacing: 0.05em;
}
.ft-pdf { background: #7f1d1d; color: #fca5a5; }
.ft-image { background: #1e1b4b; color: #a5b4fc; }
.ft-csv { background: #064e3b; color: #6ee7b7; }
.ft-excel { background: #064e3b; color: #6ee7b7; }
.ft-docx { background: #1e3a5f; color: #7dd3fc; }
.ft-text { background: #1c1917; color: #d6d3d1; }
.doc-item {
background: #1c1c26; border: 1px solid #2a2a3a; border-radius: 10px;
padding: 0.6rem 0.8rem; margin-bottom: 0.4rem;
}
[data-testid="stFileUploader"] { background: #1c1c26 !important; border: 2px dashed #2a2a3a !important; border-radius: 12px !important; }
.stButton > button {
background: linear-gradient(135deg, #7c3aed, #4f46e5) !important;
color: white !important; border: none !important; border-radius: 8px !important;
font-family: 'DM Sans', sans-serif !important; font-weight: 500 !important;
}
.stButton > button:hover { transform: translateY(-1px) !important; box-shadow: 0 4px 20px rgba(124,58,237,0.4) !important; }
.stTextInput > div > div > input, [data-testid="stChatInputTextArea"] {
background: #1c1c26 !important; border: 1px solid #2a2a3a !important;
color: #e8e8f0 !important; border-radius: 10px !important;
}
.badge-ready { background:#14532d; color:#86efac; padding:3px 10px; border-radius:20px; font-size:0.75rem; }
.badge-empty { background:#1c1917; color:#a8a29e; padding:3px 10px; border-radius:20px; font-size:0.75rem; }
.badge-count { background:#312e81; color:#a5b4fc; padding:3px 10px; border-radius:20px; font-size:0.75rem; }
hr { border-color: #2a2a3a !important; }
::-webkit-scrollbar { width: 6px; }
::-webkit-scrollbar-track { background: #0f0f13; }
::-webkit-scrollbar-thumb { background: #2a2a3a; border-radius: 3px; }
</style>
""", unsafe_allow_html=True)
# ─── Cache RAG engine ─────────────────────────────────────────────────────────
@st.cache_resource(show_spinner=False)
def load_rag_engine():
    """Return the application's RAGEngine, created on first use.

    Wrapped in ``st.cache_resource``, so later calls (every rerun, every
    session) get the same instance back instead of building a new one.
    The ``rag_engine`` import is deferred to the first call.
    """
    from rag_engine import RAGEngine as _RAGEngine

    engine = _RAGEngine()
    return engine
# ─── Session state ────────────────────────────────────────────────────────────
# Per-session state seeded on the first run only; reruns keep stored values.
defaults = {
    "messages": [],  # chat transcript: dicts with "role" and "content" (answers also carry "sources"/"memory_count")
    "processed_files": {},  # {filename: md5_hash}
}
for state_key, initial_value in defaults.items():
    if state_key in st.session_state:
        continue
    st.session_state[state_key] = initial_value
# Lower-case extension -> (CSS class suffix for ``.ft-*``, badge label).
_FILE_TYPE_BADGES = {
    ".pdf": ("pdf", "PDF"),
    ".txt": ("text", "TXT"),
    ".docx": ("docx", "DOCX"),
    ".doc": ("docx", "DOC"),
    ".csv": ("csv", "CSV"),
    ".xlsx": ("excel", "XLSX"),
    ".xls": ("excel", "XLS"),
    ".jpg": ("image", "IMAGE"),
    ".jpeg": ("image", "IMAGE"),
    ".png": ("image", "IMAGE"),
    ".webp": ("image", "IMAGE"),
}


def file_type_badge(suffix: str) -> str:
    """Return an HTML ``<span>`` badge describing a file extension.

    Args:
        suffix: File extension including the leading dot (e.g. ".pdf").
            Matched case-insensitively.

    Returns:
        Markup like ``<span class="filetype-badge ft-pdf">PDF</span>``.
        Unknown extensions use the neutral ``ft-text`` style and are labelled
        with the extension itself (no dot, upper-cased, HTML-escaped); an
        empty suffix is labelled "FILE".
    """
    key = (suffix or "").lower()
    try:
        cls, label = _FILE_TYPE_BADGES[key]
    except KeyError:
        # The label ends up inside raw HTML, so escape the user-derived text.
        cls, label = "text", html.escape(key.lstrip(".").upper()) or "FILE"
    return f'<span class="filetype-badge ft-{cls}">{label}</span>'
# Lower-case extension -> emoji shown next to a document.
_TYPE_EMOJIS = {
    ".pdf": "📄", ".txt": "📄",
    ".docx": "📝", ".doc": "📝",
    ".csv": "📊", ".xlsx": "📊", ".xls": "📊",
    ".jpg": "🖼️", ".jpeg": "🖼️", ".png": "🖼️", ".webp": "🖼️",
}


def type_emoji(suffix: str) -> str:
    """Return an emoji for a file extension.

    Args:
        suffix: File extension including the leading dot (e.g. ".csv").
            Matched case-insensitively.

    Returns:
        The emoji for known document, spreadsheet and image types, or "📄"
        for anything else (including an empty suffix).
    """
    return _TYPE_EMOJIS.get((suffix or "").lower(), "📄")
# ─── Load RAG engine & get document state ─────────────────────────────────────
# Snapshot the engine's state once per script run; the sidebar and the main
# area below all read these names.
rag = load_rag_engine()
documents = rag.get_documents()  # [{name, type, chunk_count}]
doc_loaded = len(documents) != 0
total_chunks, file_count = rag.get_total_chunks(), rag.get_file_count()
# ─── Sidebar ──────────────────────────────────────────────────────────────────
with st.sidebar:
    st.markdown('<p style="font-family:Syne,sans-serif;font-size:1.3rem;font-weight:700;color:#a78bfa;">🧠 DocMind AI</p>', unsafe_allow_html=True)
    st.markdown('<p style="color:#6b6b8a;font-size:0.8rem;">Multimodal RAG · Multi-File · Memory</p>', unsafe_allow_html=True)
    st.markdown("---")

    # ── Document List ─────────────────────────────────────────────────────────
    if documents:
        mem_count = rag.get_memory_count()
        st.markdown(
            f'<span class="badge-ready">✓ Ready</span> '
            f'<span class="badge-count">{file_count}/{MAX_FILES} files</span>',
            unsafe_allow_html=True,
        )
        st.markdown(
            f'<p style="color:#6b6b8a;font-size:0.78rem;margin-top:0.3rem;">'
            f'{total_chunks} total chunks · {mem_count} exchanges in memory</p>',
            unsafe_allow_html=True,
        )
        st.markdown("")
        # Show each document with a remove button
        for doc in documents:
            col_doc, col_rm = st.columns([5, 1])
            with col_doc:
                badge = file_type_badge(doc["type"])
                emoji = type_emoji(doc["type"])
                # File names come from user uploads: escape them before
                # embedding them in raw HTML.
                safe_name = html.escape(str(doc["name"]))
                st.markdown(
                    f'<div class="doc-item">'
                    f'{badge} <b style="color:#e8e8f0;font-size:0.82rem;">{safe_name}</b>'
                    f'<br><span style="color:#6b6b8a;font-size:0.72rem;">'
                    f'{emoji} {doc["chunk_count"]} chunks</span>'
                    f'</div>',
                    unsafe_allow_html=True,
                )
            with col_rm:
                st.markdown('<div style="padding-top:0.6rem;"></div>', unsafe_allow_html=True)
                if st.button("❌", key=f"rm_{doc['name']}", help=f"Remove {doc['name']}"):
                    rag.remove_file(doc["name"])
                    # Forget its hash so the same file can be uploaded again later.
                    st.session_state.processed_files.pop(doc["name"], None)
                    st.rerun()
    else:
        st.markdown('<span class="badge-empty">○ No documents loaded</span>', unsafe_allow_html=True)
    st.markdown("---")

    # ── Upload Area ───────────────────────────────────────────────────────────
    st.markdown(
        '<p style="color:#6b6b8a;font-size:0.78rem;font-weight:600;text-transform:uppercase;letter-spacing:0.08em;">'
        'Upload Document</p>',
        unsafe_allow_html=True,
    )
    st.markdown(
        '<p style="color:#6b6b8a;font-size:0.72rem;">'
        'PDF · TXT · DOCX · CSV · XLSX · JPG · PNG</p>',
        unsafe_allow_html=True,
    )
    # Result of the last ingestion attempt. It is queued in session state
    # because st.rerun() discards anything rendered earlier in the same run.
    pending_notice = st.session_state.pop("sidebar_notice", None)
    if pending_notice:
        notice_level, notice_text = pending_notice
        getattr(st, notice_level)(notice_text)  # st.success / st.error / st.info

    # st.file_uploader keeps its file across reruns. The widget key includes a
    # counter that is bumped after every ingestion attempt, so a fresh, empty
    # uploader replaces it. Without this, the last file would be re-ingested
    # right after it is removed (❌) or after "Reset All", and a file that
    # failed to ingest would be retried on every rerun.
    uploader_generation = st.session_state.get("uploader_generation", 0)
    if file_count >= MAX_FILES:
        st.warning(f"Maximum {MAX_FILES} files reached. Remove a file to upload more.")
        uploaded_file = None
    else:
        uploaded_file = st.file_uploader(
            "Upload",
            type=["pdf", "txt", "docx", "doc", "csv", "xlsx", "xls",
                  "jpg", "jpeg", "png", "webp"],
            label_visibility="collapsed",
            key=f"uploader_{uploader_generation}",
        )
    if uploaded_file:
        # getvalue() returns the whole buffer without moving the stream
        # position, so ingest_file still reads the file from the start.
        file_hash = hashlib.md5(uploaded_file.getvalue()).hexdigest()
        if file_hash in st.session_state.processed_files.values():
            # The same bytes are already indexed (possibly under another name).
            outcome = ("info", f"{uploaded_file.name} is already indexed.")
        else:
            suffix = Path(uploaded_file.name).suffix.lower()
            type_msg = {
                ".pdf": "Reading PDF...",
                ".txt": "Reading text...",
                ".docx": "Reading Word doc...",
                ".csv": "Parsing CSV...",
                ".xlsx": "Parsing Excel...",
                ".xls": "Parsing Excel...",
                ".jpg": "🖼️ Processing image (OCR + captioning)...",
                ".jpeg": "🖼️ Processing image (OCR + captioning)...",
                ".png": "🖼️ Processing image (OCR + captioning)...",
                ".webp": "🖼️ Processing image (OCR + captioning)...",
            }.get(suffix, "Processing...")
            with st.spinner(type_msg):
                try:
                    chunks = rag.ingest_file(uploaded_file)
                except ValueError as e:
                    outcome = ("error", str(e))
                except Exception as e:
                    outcome = ("error", f"Failed to process file: {e}")
                else:
                    st.session_state.processed_files[uploaded_file.name] = file_hash
                    outcome = ("success", f"✓ Indexed {chunks} chunks from {uploaded_file.name}!")
        st.session_state["sidebar_notice"] = outcome
        st.session_state["uploader_generation"] = uploader_generation + 1
        # Called outside any try block so the rerun signal is never intercepted.
        st.rerun()
    st.markdown("---")

    # ── Sample doc ────────────────────────────────────────────────────────────
    st.markdown(
        '<p style="color:#6b6b8a;font-size:0.78rem;font-weight:600;text-transform:uppercase;letter-spacing:0.08em;">'
        'Or try a sample</p>',
        unsafe_allow_html=True,
    )
    if st.button("📥 Load Sample: AI Report", use_container_width=True):
        if file_count >= MAX_FILES:
            st.error(f"Maximum {MAX_FILES} files reached. Remove a file first.")
        else:
            try:
                with st.spinner("Downloading sample..."):
                    from data_downloader import download_sample_doc
                    path, name = download_sample_doc()
                chunks = rag.ingest_path(path, name)
            except ValueError as e:
                st.error(str(e))
            except Exception as e:
                # Download or indexing failures are shown here instead of
                # aborting the whole script run.
                st.error(f"Failed to load sample: {e}")
            else:
                st.session_state.processed_files[name] = "sample"
                st.session_state["sidebar_notice"] = ("success", f"✓ {chunks} chunks loaded!")
                st.rerun()
    st.markdown("---")

    # ── Action buttons ────────────────────────────────────────────────────────
    col_a, col_b = st.columns(2)
    with col_a:
        if st.button("🗑️ Clear Chat", use_container_width=True):
            st.session_state.messages = []
            rag.clear_memory()
            st.rerun()
    with col_b:
        if st.button("🔄 Reset All", use_container_width=True):
            rag.reset()
            st.session_state.messages = []
            st.session_state.processed_files = {}
            st.rerun()
    st.markdown("---")
    st.markdown(f"""
<p style="color:#6b6b8a;font-size:0.72rem;line-height:1.8;">
<b style="color:#a78bfa;">Stack</b><br>
🔗 LangChain · ChromaDB<br>
🤗 MiniLM Embeddings<br>
🦙 Llama-3 / Mistral-7B<br>
🖼️ BLIP + VLM Captioning<br>
💬 Conversation Memory<br>
📁 Up to {MAX_FILES} files simultaneously<br>
🌊 Streamlit + FastAPI
</p>
""", unsafe_allow_html=True)
# ─── Main Area ────────────────────────────────────────────────────────────────
# Page title and tagline; the file limit comes from MAX_FILES so the text
# cannot drift from the enforced limit.
st.markdown('<h1 class="hero-title">DocMind AI</h1>', unsafe_allow_html=True)
st.markdown(
    '<p class="hero-sub">'
    f'PDF · Word · CSV · Excel · Images — Upload up to {MAX_FILES} files. Ask anything. Remembers your conversation.'
    '</p>',
    unsafe_allow_html=True,
)
# ── Stats ─────────────────────────────────────────────────────────────────────
# Four headline metric cards in a row, followed by a spacer.
c1, c2, c3, c4 = st.columns(4)
stat_cards = (
    (total_chunks or "—", "Chunks Indexed"),
    (f"{file_count}/{MAX_FILES}", "Files Loaded"),
    # Each question contributes a user message plus an assistant reply.
    (len(st.session_state.messages) // 2, "Questions Asked"),
    (rag.get_memory_count(), "Memory Window"),
)
for card_column, (card_value, card_caption) in zip((c1, c2, c3, c4), stat_cards):
    with card_column:
        st.markdown(
            '<div class="stat-card">'
            f'<div class="stat-number">{card_value}</div>'
            f'<div class="stat-label">{card_caption}</div></div>',
            unsafe_allow_html=True,
        )
st.markdown("<br>", unsafe_allow_html=True)
# ─── Chat history ─────────────────────────────────────────────────────────────
if not st.session_state.messages:
    if doc_loaded:
        # Summary of the loaded files. Names are user-controlled, so they are
        # HTML-escaped before going into raw markup.
        file_names = ", ".join(
            f"<b style='color:#e8e8f0;'>{html.escape(str(d['name']))}</b>" for d in documents
        )
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # emoji row does not reshuffle between processes the way set order can.
        emojis = " ".join(dict.fromkeys(type_emoji(d["type"]) for d in documents))
        more_files_hint = (
            f"You can also upload more files (up to {MAX_FILES})."
            if file_count < MAX_FILES else ""
        )
        st.markdown(f"""
<div style="text-align:center;padding:3rem;color:#6b6b8a;">
<div style="font-size:2.5rem;margin-bottom:1rem;">{emojis}</div>
<p style="font-size:1rem;color:#a78bfa;">
{file_count} document{'s' if file_count > 1 else ''} ready!
</p>
<p style="font-size:0.85rem;">Ask anything about {file_names}</p>
<p style="font-size:0.78rem;margin-top:0.5rem;">
I'll remember your conversation — ask follow-up questions naturally.
{more_files_hint}
</p>
</div>""", unsafe_allow_html=True)
    else:
        st.markdown(f"""
<div style="text-align:center;padding:4rem 2rem;color:#6b6b8a;">
<div style="font-size:3rem;margin-bottom:1rem;">🧠</div>
<p style="font-size:1.1rem;color:#a78bfa;font-family:'Syne',sans-serif;font-weight:600;">
Multimodal RAG — Upload up to {MAX_FILES} files
</p>
<p style="font-size:0.85rem;margin-top:0.5rem;">
📄 PDF &nbsp;·&nbsp; 📝 Word &nbsp;·&nbsp; 📊 CSV/Excel &nbsp;·&nbsp; 🖼️ Images<br><br>
Upload in the sidebar or load the sample AI report to get started.<br>
You can upload multiple files and ask questions across all of them.
</p>
</div>""", unsafe_allow_html=True)
else:
    for msg in st.session_state.messages:
        # Prompts are typed by the user and answers are model output that can
        # echo uploaded document text. Both are escaped so they render as
        # text rather than being interpreted as HTML.
        content = html.escape(str(msg["content"]), quote=False)
        if msg["role"] == "user":
            st.markdown(f"""
<div class="chat-user">
<div class="chat-label label-user">You</div>
{content}
</div>""", unsafe_allow_html=True)
        else:
            mem = msg.get("memory_count", 0)
            mem_badge = f'<span class="memory-badge">💬 {mem} in memory</span>' if mem > 0 else ""
            sources_html = ""
            if msg.get("sources"):
                pills = "".join(
                    f'<span class="source-pill">📎 {html.escape(str(s), quote=False)}</span>'
                    for s in msg["sources"]
                )
                sources_html = f'<div style="margin-top:0.7rem;">{pills}</div>'
            st.markdown(f"""
<div class="chat-assistant">
<div class="chat-label label-ai">DocMind AI {mem_badge}</div>
{content}
{sources_html}
</div>""", unsafe_allow_html=True)
# ─── Chat Input ───────────────────────────────────────────────────────────────
st.markdown("<br>", unsafe_allow_html=True)
if not doc_loaded:
    st.chat_input("Upload a document first...", disabled=True)
else:
    # Tailor the placeholder to the file type when exactly one file is loaded.
    if file_count == 1:
        placeholder = {
            ".pdf": "Ask anything about this PDF...",
            ".txt": "Ask anything about this text...",
            ".docx": "Ask anything about this document...",
            ".doc": "Ask anything about this document...",
            ".csv": "Ask about the data, columns, or statistics...",
            ".xlsx": "Ask about the spreadsheet data...",
            ".xls": "Ask about the spreadsheet data...",
            ".jpg": "Ask me what I see in this image...",
            ".jpeg": "Ask me what I see in this image...",
            ".png": "Ask me what I see in this image...",
            ".webp": "Ask me what I see in this image...",
        }.get(documents[0]["type"], "Ask anything about your document...")
    else:
        placeholder = f"Ask anything about your {file_count} documents..."
    if prompt := st.chat_input(placeholder):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.spinner("🔍 Retrieving & generating..."):
            try:
                answer, sources = rag.query(prompt)
            except Exception as e:
                # Reply with the error instead of aborting the run, so every
                # user message keeps its assistant reply (the "Questions
                # Asked" stat counts len(messages) // 2).
                answer, sources = f"⚠️ Sorry, I couldn't answer that: {e}", []
        mem_count = rag.get_memory_count()
        st.session_state.messages.append({
            "role": "assistant",
            "content": answer,
            "sources": sources,
            "memory_count": mem_count,
        })
        st.rerun()