# rag / app.py
# Last commit by jessica45: "Rename main.py to app.py" (58dad37, verified)
import hashlib
import html
import os
import tempfile
from typing import List, Optional

import streamlit as st
from dotenv import load_dotenv

from rag_with_gemini import RAGSystem
# Load environment variables (the API keys read in initialize_rag_system) from a .env file.
load_dotenv()

# --- PAGE CONFIG ---
# Global page settings for the Streamlit app: browser title, favicon, wide layout,
# and a sidebar that starts open.
st.set_page_config(
    page_title="RAG Document Assistant",
    # Previously the mis-encoded "πŸ€–", which Streamlit cannot interpret as an emoji.
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
)
# --- SESSION STATE INIT ---
def initialize_session_state():
    """Seed ``st.session_state`` with default values for any key not yet present.

    Existing values are left untouched, so calling this on every rerun is safe.
    Each default is produced by calling a factory, so every session gets its own
    fresh list/set objects.
    """
    default_factories = {
        'rag_system': lambda: None,
        'documents_processed': list,
        # SHA-256 hashes of processed files, used to avoid reprocessing the same
        # file within a session.
        'processed_hashes': set,
        'chat_history': list,
        'processing_status': str,   # str() -> ""
        'system_initialized': bool,  # bool() -> False
    }
    for state_key, make_default in default_factories.items():
        if state_key in st.session_state:
            continue
        st.session_state[state_key] = make_default()
# --- RAG SYSTEM INIT ---
def initialize_rag_system():
    """Create the session's RAGSystem once, from credentials in the environment.

    Reads GEMINI_API_KEY, QDRANT_URL and QDRANT_API_KEY. On success the instance
    is stored in ``st.session_state.rag_system`` and ``system_initialized`` is set,
    so later reruns return immediately.

    Returns:
        True if the system is (now) initialized, False if a variable is missing or
        construction raised; in both failure cases an error is shown with st.error.
    """
    if st.session_state.system_initialized:
        return True

    try:
        gemini_api_key, qdrant_url, qdrant_api_key = (
            os.getenv(var_name)
            for var_name in ('GEMINI_API_KEY', 'QDRANT_URL', 'QDRANT_API_KEY')
        )
        # Every one of the three settings must be present and non-empty.
        if not all((gemini_api_key, qdrant_url, qdrant_api_key)):
            st.error("❌ Missing API keys in your .env file.")
            return False

        with st.spinner("πŸš€ Initializing RAG system..."):
            st.session_state.rag_system = RAGSystem(gemini_api_key, qdrant_url, qdrant_api_key)
        st.session_state.system_initialized = True
        return True
    except Exception as e:
        st.error(f"❌ Initialization error: {e}")
        return False
# --- DOCUMENT PROCESSING ---
def process_uploaded_files(uploaded_files):
    """Index newly uploaded files into the session's RAG system.

    Each file's contents are hashed with SHA-256. A file is skipped if its hash
    was already processed earlier in this session, or if it repeats the contents
    of another file in the same batch. New files are written to temporary files,
    passed together to ``rag_system.add_documents``, and the temporary files are
    always removed afterwards, even when processing raises.

    The outcome is reported to the UI via ``st.session_state.processing_status``.

    Args:
        uploaded_files: Sequence of uploaded files (objects with ``.name`` and
            ``.getvalue()``), as returned by ``st.file_uploader``.

    Returns:
        False if there are no files or no RAG system, if ``add_documents``
        reports failure, or if an exception occurs. True if the files were
        indexed, or if every file was a duplicate and there was nothing to do.
    """
    if not uploaded_files or not st.session_state.rag_system:
        return False

    temp_paths = []  # every temp file created below; removed in `finally`
    try:
        to_process = []       # (filename, sha256 hex digest) of files sent for indexing
        skipped = []          # filenames whose contents were already seen
        batch_hashes = set()  # hashes queued in this call, so in-batch duplicates are skipped too

        # Determine which files are new by hashing their contents.
        for uploaded_file in uploaded_files:
            data = uploaded_file.getvalue()
            digest = hashlib.sha256(data).hexdigest()
            if digest in st.session_state.processed_hashes or digest in batch_hashes:
                skipped.append(uploaded_file.name)
                continue
            batch_hashes.add(digest)

            # Keep the original extension on the temp file; presumably add_documents
            # picks a loader from it (TODO confirm against RAGSystem).
            suffix = os.path.splitext(uploaded_file.name)[1]
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                # Record the path before writing so a failed write is still cleaned up.
                temp_paths.append(tmp.name)
                tmp.write(data)
            to_process.append((uploaded_file.name, digest))

        # Nothing new to index: report what was skipped and short-circuit.
        if not temp_paths:
            st.session_state.processing_status = f"⚠️ No new files to process. Skipped: {', '.join(skipped)}" if skipped else "⚠️ No files provided."
            return True

        with st.spinner("πŸ“„ Processing documents..."):
            success = st.session_state.rag_system.add_documents(temp_paths)

        if not success:
            st.session_state.processing_status = "❌ Failed to process documents."
            return False

        # Record processed filenames and their hashes so later uploads are skipped.
        for name, digest in to_process:
            st.session_state.documents_processed.append(name)
            st.session_state.processed_hashes.add(digest)

        status_msg = f"βœ… Processed {len(to_process)} documents!"
        if skipped:
            status_msg += f" Skipped {len(skipped)} duplicate(s): {', '.join(skipped)}"
        st.session_state.processing_status = status_msg
        return True
    except Exception as e:
        st.session_state.processing_status = f"❌ Error: {str(e)}"
        return False
    finally:
        # Best-effort cleanup: a temp file that cannot be removed must not turn a
        # successful run into a failure.
        for path in temp_paths:
            try:
                os.unlink(path)
            except OSError:
                pass
# --- CHAT DISPLAY ---
def display_chat_message(role: str, content: str, sources: Optional[List[str]] = None):
    """Render one chat message as a Streamlit chat bubble with a role-specific avatar.

    Args:
        role: Chat role. "assistant" gets the assistant avatar; every other role
            gets the user avatar.
        content: Message text, rendered as Markdown.
        sources: Source references attached to the message. Callers pass them
            from the chat history, but this function does not currently render
            them.
    """
    if role == "assistant":
        avatar_url = "https://cdn-icons-png.flaticon.com/512/4712/4712035.png"
    else:
        avatar_url = "https://cdn-icons-png.flaticon.com/512/1077/1077012.png"
    with st.chat_message(role, avatar=avatar_url):
        st.markdown(content)
# --- MAIN ---
# Placeholder shown in the main area until the user has chatted or processed a document.
_WELCOME_HTML = """
<div style="text-align:center; padding:3rem; color:#9ca3af;">
<h3>πŸ‘‹ Welcome to your RAG Assistant</h3>
<p>Upload documents in the sidebar, then ask me anything about their content.</p>
</div>
"""


def _render_sidebar():
    """Render the sidebar: upload/process controls, processing status, processed files, Clear Chat."""
    with st.sidebar:
        st.markdown("### πŸ“ Upload Documents")
        uploaded_files = st.file_uploader("Choose files", type=['pdf', 'txt', 'docx'], accept_multiple_files=True)
        if uploaded_files and st.button("πŸ“€ Process Documents"):
            if process_uploaded_files(uploaded_files):
                st.rerun()

        if st.session_state.processing_status:
            msg = st.session_state.processing_status
            # NOTE(review): the emoji literals look mis-encoded (e.g. "βœ…" for ✅). This
            # check must match the exact string process_uploaded_files writes, so fix both together.
            cls = "success-message" if "βœ…" in msg else "error-message"
            # The status can contain uploaded filenames and exception text, so escape
            # it before injecting it into raw HTML.
            st.markdown(f'<div class="{cls}">{html.escape(msg)}</div>', unsafe_allow_html=True)

        if st.session_state.documents_processed:
            st.markdown("### βœ… Processed Files")
            for doc in st.session_state.documents_processed:
                st.write(f"- {doc}")

        if st.button("πŸ—‘οΈ Clear Chat"):
            st.session_state.chat_history = []
            st.rerun()


def _answer_prompt(prompt):
    """Record and show the user's question, query the RAG system, then record and show the reply.

    The reply (or error) goes through display_chat_message, so it looks the same
    now as when the chat history is re-rendered on later reruns.
    """
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    display_chat_message("user", prompt)

    try:
        with st.spinner("πŸ€” Thinking..."):
            result = st.session_state.rag_system.query(prompt)
        # Assumes query() returns a dict with 'answer' and usually 'sources' (TODO confirm).
        reply = {
            "role": "assistant",
            "content": result['answer'],
            # Sources are kept in the history even though they are not displayed.
            "sources": result.get('sources', []),
        }
    except Exception as e:
        reply = {"role": "assistant", "content": f"❌ Error: {str(e)}"}

    display_chat_message(reply["role"], reply["content"], reply.get("sources", []))
    st.session_state.chat_history.append(reply)


def main():
    """Streamlit entry point: set up state and the RAG system, then render the page."""
    initialize_session_state()
    st.markdown('<h1 class="main-header">RAG Document Assistant</h1>', unsafe_allow_html=True)
    if not initialize_rag_system():
        st.stop()

    _render_sidebar()

    if not st.session_state.chat_history and not st.session_state.documents_processed:
        st.markdown(_WELCOME_HTML, unsafe_allow_html=True)

    for message in st.session_state.chat_history:
        display_chat_message(message["role"], message["content"], message.get("sources", []))

    # Chat input
    if prompt := st.chat_input("πŸ’¬ Ask me anything..."):
        if not st.session_state.documents_processed:
            st.warning("⚠️ Upload and process documents first!")
            return
        _answer_prompt(prompt)


if __name__ == "__main__":
    main()