# MultiModel-Rag / app.py
# Author: dev-models
# Commit 0fe8565: "app file updated"
import streamlit as st
import os
import base64
from io import BytesIO
from PIL import Image
import time
# Import Modular components
from backend.rag import RAGEngine
from backend.parser import EnrichedRagParser
import tempfile
# ==========================================
# 1. Page Configuration & Professional CSS
# ==========================================
# Configure the browser tab (title, icon) and the overall page layout.
# The icon was previously a mis-decoded (mojibake) byte sequence; it is
# restored to the robot-face emoji (U+1F916) that those bytes encode.
st.set_page_config(
    page_title="Multimodal RAG Assistant",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Production-ready CSS
# Stylesheet for chat bubbles and the sidebar "stats" card. It relies on
# the theme's CSS variables so it follows light/dark mode.
_APP_STYLESHEET = """
<style>
.stChatMessage {
background-color: var(--secondary-background-color);
border: 1px solid rgba(128, 128, 128, 0.1);
border-radius: 12px;
padding: 1.5rem;
margin-bottom: 1rem;
box-shadow: 0 2px 4px rgba(0,0,0,0.05);
}
.stats-container {
background-color: var(--secondary-background-color);
border: 1px solid rgba(128, 128, 128, 0.2);
border-radius: 10px;
padding: 15px;
margin-top: 10px;
}
.stats-header {
font-weight: 600;
color: var(--text-color);
margin-bottom: 8px;
display: block;
}
.stats-item {
font-size: 0.9em;
color: var(--text-color);
opacity: 0.8;
margin-bottom: 4px;
display: flex;
justify-content: space-between;
}
</style>
"""

# Raw HTML must be explicitly allowed for the <style> tag to take effect.
st.markdown(_APP_STYLESHEET, unsafe_allow_html=True)
# ==========================================
# 2. Initialization & Helper Functions
# ==========================================
@st.cache_resource
def initialize_rag_system(force_clean: bool = True):
    """Build the hybrid-search RAG engine; the result is cached by Streamlit.

    Args:
        force_clean: Forwarded unchanged to ``RAGEngine``.

    Returns:
        The ``RAGEngine`` instance created with ``use_hybrid=True``.
    """
    engine = RAGEngine(use_hybrid=True, force_clean=force_clean)
    return engine
def display_image_from_base64(base64_str: str, caption: str = "", width: int = 300):
    """Decode a base64-encoded image and render it on the page.

    Args:
        base64_str: Base64 text of the encoded image bytes.
        caption: Caption passed through to ``st.image``.
        width: Display width passed through to ``st.image``.

    Any exception raised while decoding, opening, or rendering is reported
    via ``st.error`` instead of propagating to the caller.
    """
    try:
        raw_bytes = base64.b64decode(base64_str)
        buffer = BytesIO(raw_bytes)
        picture = Image.open(buffer)
        st.image(picture, caption=caption, width=width)
    except Exception as exc:
        st.error(f"Failed to display image: {exc}")
# ==========================================
# 3. Main Application
# ==========================================
# Number of columns used when laying out retrieved images in a grid.
_IMAGE_GRID_COLUMNS = 3

# Placeholder shown in the chat area while the conversation is empty.
_WELCOME_HTML = """
<div style="text-align: center; margin-top: 50px; opacity: 0.7;">
<h3>👋 Ready to help!</h3>
<p>Upload a PDF in the sidebar to start.</p>
</div>
"""

# Sidebar status card; ``{count}`` is filled in on every script run.
_STATS_HTML_TEMPLATE = """
<div class="stats-container">
<span class="stats-header">📊 Database Status</span>
<div class="stats-item"><span>Total Chunks:</span> <strong>{count}</strong></div>
<div class="stats-item"><span>Embedding:</span> <strong>CLIP ViT-L/14</strong></div>
</div>
"""


def _render_image_grid(images, width: int = 220) -> None:
    """Draw a divider followed by the given images in a fixed-column grid.

    Each item of ``images`` must be a mapping with an ``"image_base64"`` key.
    """
    st.markdown("---")
    cols = st.columns(_IMAGE_GRID_COLUMNS)
    for idx, img in enumerate(images):
        with cols[idx % _IMAGE_GRID_COLUMNS]:
            display_image_from_base64(img["image_base64"], width=width)


def _process_uploaded_pdf(rag, uploaded_file) -> bool:
    """Parse the uploaded PDF, ingest it, and refresh the suggested questions.

    The upload is spooled to a temporary file only for the duration of this
    call and is always removed afterwards, whether or not processing worked.

    Returns:
        True when the document was ingested, False when an error was shown.
    """
    file_path = None
    try:
        # Keep a ".pdf" suffix in case the parser looks at the extension
        # (the uploader only accepts PDFs, so the suffix is always accurate).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
            # Record the path before writing so a failed write is still cleaned up.
            file_path = tmp.name
            # getvalue() returns the full content regardless of the current
            # read position, unlike read().
            tmp.write(uploaded_file.getvalue())

        with st.spinner("Analyzing PDF with Docling..."):
            parser = EnrichedRagParser()
            parsed_data = parser.process_document(file_path)

        with st.spinner("Ingesting into MongoDB..."):
            rag.ingest_data(parsed_data)
            # Generate Suggestions
            suggestions = rag.generate_suggested_questions(num_questions=6)
            st.session_state.suggested_questions = suggestions

        st.success(f"Processed: {uploaded_file.name}")
        return True
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing.
        st.error(f"❌ Error: {str(e)}")
        return False
    finally:
        # Always clean up the temp file.
        if file_path and os.path.exists(file_path):
            os.remove(file_path)
            print("🧹 Temp file deleted")


def main():
    """Render the Streamlit page: sidebar controls, chat history, and replies.

    One call corresponds to one Streamlit script run. The function:

    1. Ensures ``messages``, ``suggested_questions`` and ``rag`` exist in
       ``st.session_state``.
    2. Draws the sidebar (PDF upload/processing, suggested questions, search
       settings, collection stats, clear/delete buttons).
    3. Replays the stored conversation.
    4. Appends new user input to the history and reruns.
    5. If the last stored message is from the user, produces the assistant
       reply (image search or streamed text answer) and stores it.
    """
    # --- State Management ---
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "suggested_questions" not in st.session_state:
        st.session_state.suggested_questions = []

    # Initialize Backend
    if "rag" not in st.session_state:
        with st.spinner("🚀 Booting up AI System..."):
            st.session_state.rag = initialize_rag_system()
    rag: RAGEngine = st.session_state.rag

    # ==========================================
    # SIDEBAR: Control Panel
    # ==========================================
    with st.sidebar:
        st.header("🧠 RAG Control Panel")

        # --- PDF Document Upload ---
        with st.expander("📂 Knowledge Base", expanded=True):
            uploaded_file = st.file_uploader(
                "Upload Document (PDF)",
                type=["pdf"],
                label_visibility="collapsed",
            )
            if uploaded_file:
                if st.button("🚀 Process PDF", type="primary", use_container_width=True):
                    if _process_uploaded_pdf(rag, uploaded_file):
                        # Rerun so the sidebar picks up the new suggestions and
                        # chunk count. Pause briefly first so the success
                        # message is visible (same approach as the delete
                        # handler below). On failure we do NOT rerun, so the
                        # error message stays on screen.
                        time.sleep(1)
                        st.rerun()

        st.markdown("---")

        # --- Suggested Questions ---
        if st.session_state.suggested_questions:
            st.subheader("💡 Quick Questions")
            for idx, q in enumerate(st.session_state.suggested_questions):
                if st.button(q, key=f"sugg_{idx}", use_container_width=True):
                    # Queue the question as if the user had typed it; the
                    # assistant section below answers it on the next run.
                    st.session_state.messages.append({"role": "user", "content": q})
                    st.rerun()
            st.markdown("---")

        # --- Settings ---
        with st.expander("⚙️ Search Settings"):
            top_k = st.slider("Max Results", 1, 10, 5)
            min_score = st.slider("Confidence Threshold", 0.0, 1.0, 0.6)
            use_images = st.toggle("Enable Image Search", value=True)

        # --- System Stats ---
        count = rag.collection.count_documents({})
        st.markdown(
            _STATS_HTML_TEMPLATE.format(count=count),
            unsafe_allow_html=True,
        )

        # Reset
        if st.button("🗑️ Clear Chat", type="secondary", use_container_width=True):
            st.session_state.messages = []
            st.rerun()

        if st.button("⚠️ Delete Vector Collection", type="primary", use_container_width=True):
            with st.spinner("Deleting collection..."):
                rag.collection.delete_many({})
                # Reset in-memory indices to match empty DB
                rag.bm25_index = None
                rag.bm25_doc_map = {}
            st.success("Vector Collection Deleted!")
            time.sleep(1)  # Give user a moment to see the success message
            st.rerun()

    # ==========================================
    # MAIN: Chat Interface
    # ==========================================
    st.title("🤖 Multimodal AI Assistant")

    if not st.session_state.messages:
        st.markdown(_WELCOME_HTML, unsafe_allow_html=True)

    # Render History
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])
            # Error replies are stored without an "images" key, hence .get().
            if msg.get("images"):
                _render_image_grid(msg["images"])

    # ==========================================
    # LOGIC: Input Handling
    # ==========================================
    user_input = st.chat_input("Type your question here...")
    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        st.rerun()

    # ==========================================
    # ASSISTANT: Streaming Response Logic
    # ==========================================
    # A trailing user message means it has not been answered yet.
    if st.session_state.messages and st.session_state.messages[-1]["role"] == "user":
        last_query = st.session_state.messages[-1]["content"]
        with st.chat_message("assistant"):
            with st.spinner("🤔 Searching context..."):
                try:
                    # Simple substring heuristic to detect requests for visuals.
                    img_keywords = ["show", "image", "diagram", "figure", "picture"]
                    lowered_query = last_query.lower()
                    is_visual_request = (
                        any(k in lowered_query for k in img_keywords) and use_images
                    )

                    found_imgs = []
                    answer_text = ""

                    if is_visual_request:
                        # 🔍 Image search branch (non-streaming)
                        found_imgs = rag.search_images(
                            last_query,
                            top_k=3,
                            min_score=min_score,
                        )
                        if found_imgs:
                            answer_text = f"I found {len(found_imgs)} relevant visuals:"
                        else:
                            answer_text = "I couldn't find any relevant images."
                        # Render once
                        st.markdown(answer_text)
                    else:
                        # 🧠 Text answer branch (STREAMING)
                        # Assumes rag.answer_question returns a generator /
                        # stream. st.write_stream displays the chunks as they
                        # arrive and returns the final concatenated text.
                        stream = rag.answer_question(
                            last_query,
                            top_k=top_k,
                        )
                        answer_text = st.write_stream(stream)

                    # Render images if any
                    if found_imgs:
                        _render_image_grid(found_imgs)

                    # Persist assistant message in history
                    st.session_state.messages.append(
                        {
                            "role": "assistant",
                            "content": answer_text,
                            "images": found_imgs,
                        }
                    )
                except Exception as e:
                    # Record the failure as the assistant's reply so the
                    # trailing user message is not retried on every rerun.
                    st.error(f"Error: {e}")
                    st.session_state.messages.append(
                        {"role": "assistant", "content": f"❌ Error: {e}"}
                    )
# Entry point: build the page only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()