Spaces:
Running
Running
| import streamlit as st | |
| import os | |
| import base64 | |
| from io import BytesIO | |
| from PIL import Image | |
| import time | |
| # Import Modular components | |
| from backend.rag import RAGEngine | |
| from backend.parser import EnrichedRagParser | |
| import tempfile | |
| # ========================================== | |
| # 1. Page Configuration & Professional CSS | |
| # ========================================== | |
# Browser-tab title/icon and overall page layout for the app.
_PAGE_SETTINGS = {
    "page_title": "Multimodal RAG Assistant",
    "page_icon": "π€",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)
| # Production-ready CSS | |
# Stylesheet for chat bubbles (.stChatMessage) and the sidebar statistics
# card (.stats-*). Colours come from Streamlit theme CSS variables.
_CUSTOM_CSS = """
<style>
.stChatMessage {
background-color: var(--secondary-background-color);
border: 1px solid rgba(128, 128, 128, 0.1);
border-radius: 12px;
padding: 1.5rem;
margin-bottom: 1rem;
box-shadow: 0 2px 4px rgba(0,0,0,0.05);
}
.stats-container {
background-color: var(--secondary-background-color);
border: 1px solid rgba(128, 128, 128, 0.2);
border-radius: 10px;
padding: 15px;
margin-top: 10px;
}
.stats-header {
font-weight: 600;
color: var(--text-color);
margin-bottom: 8px;
display: block;
}
.stats-item {
font-size: 0.9em;
color: var(--text-color);
opacity: 0.8;
margin-bottom: 4px;
display: flex;
justify-content: space-between;
}
</style>
"""
# unsafe_allow_html is required so the <style> element is passed through.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
| # ========================================== | |
| # 2. Initialization & Helper Functions | |
| # ========================================== | |
def initialize_rag_system(force_clean: bool = True):
    """Construct and return a new ``RAGEngine`` with hybrid search enabled.

    Args:
        force_clean: Forwarded unchanged to the ``RAGEngine`` constructor.

    Returns:
        The newly created ``RAGEngine`` instance.

    Note:
        This function does no caching of its own; ``main()`` keeps the
        instance in ``st.session_state`` so it is built once per session.
    """
    engine_options = {"use_hybrid": True, "force_clean": force_clean}
    return RAGEngine(**engine_options)
def display_image_from_base64(base64_str: str, caption: str = "", width: int = 300):
    """Decode a base64 string into an image and render it with ``st.image``.

    Args:
        base64_str: Base64-encoded image bytes.
        caption: Caption passed to ``st.image``.
        width: Display width passed to ``st.image``.

    Any exception raised while decoding, opening or rendering is caught and
    shown to the user through ``st.error``; nothing propagates to the caller.
    """
    try:
        picture = Image.open(BytesIO(base64.b64decode(base64_str)))
        st.image(picture, caption=caption, width=width)
    except Exception as exc:
        st.error(f"Failed to display image: {exc}")
| # ========================================== | |
| # 3. Main Application | |
| # ========================================== | |
def _render_image_grid(images, columns: int = 3, width: int = 220):
    """Render base64 images (dicts with an ``image_base64`` key) in a column grid."""
    if not images:
        return
    st.markdown("---")
    grid = st.columns(columns)
    for position, image in enumerate(images):
        with grid[position % columns]:
            display_image_from_base64(image["image_base64"], width=width)


def _process_uploaded_pdf(rag, uploaded_file):
    """Parse the uploaded PDF, ingest it into ``rag`` and return suggested questions.

    The upload is spooled to a temporary file because the parser is called
    with a filesystem path. The file is always removed afterwards, whether
    or not parsing/ingestion succeeds. Exceptions propagate to the caller.
    """
    # Keep the original extension (falling back to ".pdf", the only type the
    # uploader accepts) in case the parser relies on it to detect the format.
    suffix = os.path.splitext(uploaded_file.name)[1] or ".pdf"
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp_path = tmp.name
            # getvalue() returns the whole buffer regardless of the current
            # stream position, unlike read().
            tmp.write(uploaded_file.getvalue())

        with st.spinner("Analyzing PDF with Docling..."):
            parsed_data = EnrichedRagParser().process_document(tmp_path)

        with st.spinner("Ingesting into MongoDB..."):
            rag.ingest_data(parsed_data)
            # Generate Suggestions
            return rag.generate_suggested_questions(num_questions=6)
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                # Best-effort cleanup: a failure to delete the temp file must
                # not mask the real outcome of processing.
                pass


def _answer_query(rag, query, *, top_k, min_score, use_images):
    """Render the assistant reply for ``query`` and return ``(answer_text, images)``.

    Queries containing an image-related keyword are answered with an image
    search (when ``use_images`` is on); all others are answered by streaming
    the result of ``rag.answer_question``.
    """
    img_keywords = ("show", "image", "diagram", "figure", "picture")
    lowered = query.lower()
    is_visual_request = use_images and any(k in lowered for k in img_keywords)

    if is_visual_request:
        # Image search branch (non-streaming)
        found_imgs = rag.search_images(query, top_k=3, min_score=min_score)
        if found_imgs:
            answer_text = f"I found {len(found_imgs)} relevant visuals:"
        else:
            answer_text = "I couldn't find any relevant images."
        st.markdown(answer_text)
        _render_image_grid(found_imgs)
        return answer_text, found_imgs

    # Text answer branch (streaming): st.write_stream displays the chunks as
    # they arrive and returns the accumulated answer.
    stream = rag.answer_question(query, top_k=top_k)
    return st.write_stream(stream), []


def main():
    """Run the Streamlit app: sidebar controls, chat history and Q&A loop.

    Session state keys used: ``messages`` (chat history as dicts with
    ``role``, ``content`` and optional ``images``), ``suggested_questions``
    and ``rag`` (the ``RAGEngine`` instance).
    """
    # --- State Management ---
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "suggested_questions" not in st.session_state:
        st.session_state.suggested_questions = []

    # Initialize Backend
    if "rag" not in st.session_state:
        with st.spinner("π Booting up AI System..."):
            st.session_state.rag = initialize_rag_system()
    rag: RAGEngine = st.session_state.rag

    # ==========================================
    # SIDEBAR: Control Panel
    # ==========================================
    with st.sidebar:
        st.header("π§ RAG Control Panel")

        # --- PDF Document Upload ---
        with st.expander("π Knowledge Base", expanded=True):
            uploaded_file = st.file_uploader(
                "Upload Document (PDF)",
                type=["pdf"],
                label_visibility="collapsed",
            )
            if uploaded_file and st.button(
                "π Process PDF", type="primary", use_container_width=True
            ):
                try:
                    suggestions = _process_uploaded_pdf(rag, uploaded_file)
                except Exception as e:
                    # No rerun here, so the error stays on screen.
                    st.error(f"β Error: {str(e)}")
                else:
                    st.session_state.suggested_questions = suggestions
                    st.success(f"Processed: {uploaded_file.name}")
                    time.sleep(1)  # Give user a moment to see the success message
                    # Rerun outside the try block so the rerun control-flow
                    # signal can never be reported as a processing error.
                    st.rerun()

        st.markdown("---")

        # --- Suggested Questions ---
        if st.session_state.suggested_questions:
            st.subheader("π‘ Quick Questions")
            for idx, q in enumerate(st.session_state.suggested_questions):
                if st.button(q, key=f"sugg_{idx}", use_container_width=True):
                    st.session_state.messages.append({"role": "user", "content": q})
                    st.rerun()
            st.markdown("---")

        # --- Settings ---
        with st.expander("βοΈ Search Settings"):
            top_k = st.slider("Max Results", 1, 10, 5)
            min_score = st.slider("Confidence Threshold", 0.0, 1.0, 0.6)
            use_images = st.toggle("Enable Image Search", value=True)

        # --- System Stats ---
        count = rag.collection.count_documents({})
        stats_html = (
            "\n"
            '<div class="stats-container">\n'
            '<span class="stats-header">π Database Status</span>\n'
            '<div class="stats-item"><span>Total Chunks:</span> '
            f"<strong>{count}</strong></div>\n"
            '<div class="stats-item"><span>Embedding:</span> '
            "<strong>CLIP ViT-L/14</strong></div>\n"
            "</div>\n"
        )
        st.markdown(stats_html, unsafe_allow_html=True)

        # Reset
        if st.button("ποΈ Clear Chat", type="secondary", use_container_width=True):
            st.session_state.messages = []
            st.rerun()

        if st.button("β οΈ Delete Vector Collection", type="primary", use_container_width=True):
            with st.spinner("Deleting collection..."):
                rag.collection.delete_many({})
                # Reset in-memory indices to match empty DB
                rag.bm25_index = None
                rag.bm25_doc_map = {}
            st.success("Vector Collection Deleted!")
            time.sleep(1)  # Give user a moment to see the success message
            st.rerun()

    # ==========================================
    # MAIN: Chat Interface
    # ==========================================
    st.title("π€ Multimodal AI Assistant")

    if not st.session_state.messages:
        welcome_html = (
            "\n"
            '<div style="text-align: center; margin-top: 50px; opacity: 0.7;">\n'
            "<h3>π Ready to help!</h3>\n"
            "<p>Upload a PDF in the sidebar to start.</p>\n"
            "</div>\n"
        )
        st.markdown(welcome_html, unsafe_allow_html=True)

    # Render History
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])
            _render_image_grid(msg.get("images"))

    # ==========================================
    # LOGIC: Input Handling
    # ==========================================
    user_input = st.chat_input("Type your question here...")
    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        st.rerun()

    # ==========================================
    # ASSISTANT: Streaming Response Logic
    # ==========================================
    # A trailing user message means it has not been answered yet.
    if st.session_state.messages and st.session_state.messages[-1]["role"] == "user":
        last_query = st.session_state.messages[-1]["content"]
        with st.chat_message("assistant"):
            with st.spinner("π€ Searching context..."):
                try:
                    answer_text, found_imgs = _answer_query(
                        rag,
                        last_query,
                        top_k=top_k,
                        min_score=min_score,
                        use_images=use_images,
                    )
                    # Persist assistant message in history
                    st.session_state.messages.append(
                        {
                            "role": "assistant",
                            "content": answer_text,
                            "images": found_imgs,
                        }
                    )
                except Exception as e:
                    st.error(f"Error: {e}")
                    # Recording the failure as an assistant turn also stops
                    # the same query from being retried on every rerun.
                    st.session_state.messages.append(
                        {"role": "assistant", "content": f"β Error: {e}"}
                    )
# Start the app only when this file is executed as the entry script.
if __name__ == "__main__":
    main()