import streamlit as st import os import tempfile import hashlib from typing import List from dotenv import load_dotenv from rag_with_gemini import RAGSystem # Load environment variables load_dotenv() # --- PAGE CONFIG --- st.set_page_config( page_title="RAG Document Assistant", page_icon="🤖", layout="wide", initial_sidebar_state="expanded" ) # --- SESSION STATE INIT --- def initialize_session_state(): if 'rag_system' not in st.session_state: st.session_state.rag_system = None if 'documents_processed' not in st.session_state: st.session_state.documents_processed = [] # store SHA256 hashes of processed files to avoid reprocessing the same file in a session if 'processed_hashes' not in st.session_state: st.session_state.processed_hashes = set() if 'chat_history' not in st.session_state: st.session_state.chat_history = [] if 'processing_status' not in st.session_state: st.session_state.processing_status = "" if 'system_initialized' not in st.session_state: st.session_state.system_initialized = False # --- RAG SYSTEM INIT --- def initialize_rag_system(): if st.session_state.system_initialized: return True try: gemini_api_key = os.getenv('GEMINI_API_KEY') qdrant_url = os.getenv('QDRANT_URL') qdrant_api_key = os.getenv('QDRANT_API_KEY') if not gemini_api_key or not qdrant_url or not qdrant_api_key: st.error("❌ Missing API keys in your .env file.") return False with st.spinner("🚀 Initializing RAG system..."): rag_system = RAGSystem(gemini_api_key, qdrant_url, qdrant_api_key) st.session_state.rag_system = rag_system st.session_state.system_initialized = True return True except Exception as e: st.error(f"❌ Initialization error: {e}") return False # --- DOCUMENT PROCESSING --- def process_uploaded_files(uploaded_files): if not uploaded_files or not st.session_state.rag_system: return False try: temp_paths = [] to_process = [] skipped = [] # Determine which files are new by hashing contents for uploaded_file in uploaded_files: data = uploaded_file.getvalue() h = hashlib.sha256(data).hexdigest() if h in st.session_state.processed_hashes: skipped.append(uploaded_file.name) continue # write temp file for processing with tempfile.NamedTemporaryFile(delete=False, suffix=f".{uploaded_file.name.split('.')[-1]}") as tmp: tmp.write(data) temp_paths.append(tmp.name) to_process.append((uploaded_file.name, h)) # If there are no new files to process, short-circuit if not temp_paths: st.session_state.processing_status = f"⚠️ No new files to process. Skipped: {', '.join(skipped)}" if skipped else "⚠️ No files provided." return True with st.spinner("📄 Processing documents..."): success = st.session_state.rag_system.add_documents(temp_paths) for path in temp_paths: try: os.unlink(path) except: pass if success: # record processed filenames and their hashes for name, h in to_process: st.session_state.documents_processed.append(name) st.session_state.processed_hashes.add(h) # if some were skipped, include that in the status status_msg = f"✅ Processed {len(to_process)} documents!" if skipped: status_msg += f" Skipped {len(skipped)} duplicate(s): {', '.join(skipped)}" st.session_state.processing_status = status_msg return True else: st.session_state.processing_status = "❌ Failed to process documents." return False except Exception as e: st.session_state.processing_status = f"❌ Error: {str(e)}" return False # --- CHAT DISPLAY --- def display_chat_message(role: str, content: str, sources: List[str] = None): avatar_url = ( "https://cdn-icons-png.flaticon.com/512/4712/4712035.png" if role == "assistant" else "https://cdn-icons-png.flaticon.com/512/1077/1077012.png" ) with st.chat_message(role, avatar=avatar_url): st.markdown(content) # --- MAIN --- def main(): initialize_session_state() st.markdown('
Upload documents in the sidebar, then ask me anything about their content.