"""Streamlit front end for the Multimodal RAG Assistant.

Configures the page, then defines the cached RAG-engine factory and a small
helper for rendering base64-encoded images.
"""

import base64
import os
import tempfile
import time
from io import BytesIO

import streamlit as st
from PIL import Image

# Import modular components
from backend.rag import RAGEngine
from backend.parser import EnrichedRagParser

# ==========================================
# 1. Page Configuration & Professional CSS
# ==========================================
st.set_page_config(
    page_title="Multimodal RAG Assistant",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Production-ready CSS
# NOTE(review): the stylesheet string is blank in this copy of the file --
# confirm whether the original <style> content was lost and needs restoring.
st.markdown(""" """, unsafe_allow_html=True)


# ==========================================
# 2. Initialization & Helper Functions
# ==========================================
@st.cache_resource
def initialize_rag_system(force_clean: bool = True) -> RAGEngine:
    """Create the RAG engine, memoised by Streamlit's resource cache.

    Args:
        force_clean: Passed through unchanged to ``RAGEngine``.

    Returns:
        A ``RAGEngine`` built with ``use_hybrid=True``.
    """
    return RAGEngine(use_hybrid=True, force_clean=force_clean)


def display_image_from_base64(base64_str: str, caption: str = "", width: int = 300):
    """Decode a base64 string and render it as an image in the current container.

    Any failure while decoding, opening, or rendering the image is reported
    through ``st.error`` instead of being raised, so one bad image does not
    abort the rest of the page.

    Args:
        base64_str: Base64-encoded image bytes.
        caption: Caption passed to ``st.image``.
        width: Display width passed to ``st.image``.
    """
    try:
        img_data = base64.b64decode(base64_str)
        img = Image.open(BytesIO(img_data))
        st.image(img, caption=caption, width=width)
    except Exception as e:
        # Deliberately broad: this is a UI boundary and the error is surfaced.
        st.error(f"Failed to display image: {e}")


# ==========================================
# 3.
# Main Application
# ==========================================

# Words in a query that route it to image search instead of a text answer.
_IMAGE_KEYWORDS = ("show", "image", "diagram", "figure", "picture")

# Number of columns used when laying out retrieved images.
_IMAGE_GRID_COLUMNS = 3

# Placeholder shown in the chat area while the conversation is empty.
# NOTE(review): rendered with unsafe_allow_html=True but contains no markup --
# confirm whether wrapping HTML was lost from this string.
_WELCOME_TEXT = "\n\n👋 Ready to help!\n\nUpload a PDF in the sidebar to start.\n\n"


def _render_image_grid(images):
    """Render ``images`` (dicts with an ``image_base64`` key) in a grid below a divider."""
    st.markdown("---")
    cols = st.columns(_IMAGE_GRID_COLUMNS)
    for idx, img in enumerate(images):
        with cols[idx % _IMAGE_GRID_COLUMNS]:
            display_image_from_base64(img["image_base64"], width=220)


def _process_uploaded_pdf(rag, uploaded_file):
    """Parse the uploaded PDF, ingest it, and refresh the suggested questions.

    The upload is written to a temporary file only when processing is
    requested, and that file is removed afterwards whether or not processing
    succeeded. On failure the error stays on screen (no rerun); on success the
    script reruns so the sidebar reflects the new state.
    """
    file_path = None
    try:
        # The uploader only accepts PDFs, so the suffix matches the content.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
            file_path = tmp.name
            # getvalue() returns the full upload regardless of read position.
            tmp.write(uploaded_file.getvalue())

        with st.spinner("Analyzing PDF with Docling..."):
            parser = EnrichedRagParser()
            parsed_data = parser.process_document(file_path)

        with st.spinner("Ingesting into MongoDB..."):
            rag.ingest_data(parsed_data)

            # Generate suggestions
            suggestions = rag.generate_suggested_questions(num_questions=6)
            st.session_state.suggested_questions = suggestions
    except Exception as e:
        # UI boundary: show the failure and leave it visible to the user.
        st.error(f"❌ Error: {str(e)}")
        return
    finally:
        # Always clean up the temp file; a failed delete must not mask the
        # outcome of the processing itself.
        if file_path is not None:
            try:
                os.remove(file_path)
                print("🧹 Temp file deleted")
            except OSError as cleanup_error:
                print(f"Could not delete temp file {file_path}: {cleanup_error}")

    st.success(f"Processed: {uploaded_file.name}")
    st.rerun()


def _render_sidebar(rag):
    """Draw the sidebar control panel and handle its actions.

    Returns:
        A ``(top_k, min_score, use_images)`` tuple holding the current values
        of the search-settings widgets.
    """
    with st.sidebar:
        st.header("🧠 RAG Control Panel")

        # --- PDF Document Upload ---
        with st.expander("📂 Knowledge Base", expanded=True):
            uploaded_file = st.file_uploader(
                "Upload Document (PDF)",
                type=["pdf"],
                label_visibility="collapsed",
            )

            if uploaded_file:
                if st.button("🚀 Process PDF", type="primary", use_container_width=True):
                    _process_uploaded_pdf(rag, uploaded_file)

        st.markdown("---")

        # --- Suggested Questions ---
        if st.session_state.suggested_questions:
            st.subheader("💡 Quick Questions")
            for idx, q in enumerate(st.session_state.suggested_questions):
                if st.button(q, key=f"sugg_{idx}", use_container_width=True):
                    # Queue the question as a user turn; the rerun answers it.
                    st.session_state.messages.append({"role": "user", "content": q})
                    st.rerun()
            st.markdown("---")

        # --- Settings ---
        with st.expander("⚙️ Search Settings"):
            top_k = st.slider("Max Results", 1, 10, 5)
            min_score = st.slider("Confidence Threshold", 0.0, 1.0, 0.6)
            use_images = st.toggle("Enable Image Search", value=True)

        # --- System Stats ---
        count = rag.collection.count_documents({})
        st.markdown(
            f"\n📊 Database Status\nTotal Chunks: {count}\nEmbedding: CLIP ViT-L/14\n",
            unsafe_allow_html=True,
        )

        # Reset
        if st.button("🗑️ Clear Chat", type="secondary", use_container_width=True):
            st.session_state.messages = []
            st.rerun()

        if st.button("⚠️ Delete Vector Collection", type="primary", use_container_width=True):
            with st.spinner("Deleting collection..."):
                rag.collection.delete_many({})
                # Reset in-memory indices to match the now-empty DB
                rag.bm25_index = None
                rag.bm25_doc_map = {}
            st.success("Vector Collection Deleted!")
            time.sleep(1)  # Give the user a moment to see the success message
            st.rerun()

    return top_k, min_score, use_images


def _answer_last_query(rag, last_query, top_k, min_score, use_images):
    """Answer ``last_query`` inside an assistant chat bubble and record the reply.

    Queries containing an image keyword (while image search is enabled) go to
    ``rag.search_images``; everything else is streamed from
    ``rag.answer_question``. Either way exactly one assistant message is
    appended to the history, including on failure, so the last message is no
    longer a pending user turn.
    """
    with st.chat_message("assistant"):
        with st.spinner("🤔 Searching context..."):
            try:
                lowered_query = last_query.lower()
                is_visual_request = (
                    any(k in lowered_query for k in _IMAGE_KEYWORDS) and use_images
                )

                found_imgs = []
                answer_text = ""

                if is_visual_request:
                    # Image search branch (non-streaming)
                    found_imgs = rag.search_images(
                        last_query,
                        top_k=3,
                        min_score=min_score,
                    )
                    if found_imgs:
                        answer_text = f"I found {len(found_imgs)} relevant visuals:"
                    else:
                        answer_text = "I couldn't find any relevant images."

                    # Render once
                    st.markdown(answer_text)
                else:
                    # Text answer branch (streaming).
                    # Assumes rag.answer_question returns a generator / stream;
                    # st.write_stream's return value is stored as the reply text.
                    stream = rag.answer_question(last_query, top_k=top_k)
                    answer_text = st.write_stream(stream)

                # Render images if any
                if found_imgs:
                    _render_image_grid(found_imgs)

                # Persist assistant message in history
                st.session_state.messages.append(
                    {
                        "role": "assistant",
                        "content": answer_text,
                        "images": found_imgs,
                    }
                )
            except Exception as e:
                st.error(f"Error: {e}")
                st.session_state.messages.append(
                    {"role": "assistant", "content": f"❌ Error: {e}"}
                )


def main():
    """Run one pass of the Streamlit script: sidebar, chat history, and reply."""
    # --- State Management ---
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "suggested_questions" not in st.session_state:
        st.session_state.suggested_questions = []

    # Initialize backend
    if "rag" not in st.session_state:
        with st.spinner("🚀 Booting up AI System..."):
            st.session_state.rag = initialize_rag_system()

    rag: RAGEngine = st.session_state.rag

    # ==========================================
    # SIDEBAR: Control Panel
    # ==========================================
    top_k, min_score, use_images = _render_sidebar(rag)

    # ==========================================
    # MAIN: Chat Interface
    # ==========================================
    st.title("🤖 Multimodal AI Assistant")

    if not st.session_state.messages:
        st.markdown(_WELCOME_TEXT, unsafe_allow_html=True)

    # Render history
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])
            if "images" in msg and msg["images"]:
                _render_image_grid(msg["images"])

    # ==========================================
    # LOGIC: Input Handling
    # ==========================================
    user_input = st.chat_input("Type your question here...")
    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        st.rerun()

    # ==========================================
    # ASSISTANT: Streaming Response Logic
    # ==========================================
    # A trailing user message means a question is still waiting for its reply.
    if st.session_state.messages and st.session_state.messages[-1]["role"] == "user":
        _answer_last_query(
            rag,
            st.session_state.messages[-1]["content"],
            top_k,
            min_score,
            use_images,
        )


if __name__ == "__main__":
    main()