project2 / src /app.py
dnj0's picture
Update src/app.py
d3aa2b9 verified
import streamlit as st
import os
from pathlib import Path
from rag_pipeline import RAGPipeline
import time
# Page configuration
# NOTE(review): the emoji in the UI strings below look mis-encoded (mojibake) in
# this copy of the file; they are reproduced as-is - confirm against the original.
page_settings = dict(
    page_title="Local Multimodal RAG",
    page_icon="πŸ“š",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**page_settings)

# Page heading and one-line description shown above everything else.
st.title("πŸ“š Local Multimodal RAG System")
st.markdown("**Analyze PDF documents locally with Mistral + CLIP embeddings**")
# Initialize session state
# Each key is seeded only when absent, so values survive Streamlit reruns.
session_defaults = {
    "uploaded_files": [],
    "rag_pipeline": None,
    "last_upload_time": 0,
}
for state_key, initial_value in session_defaults.items():
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value
# Sidebar configuration
# Builds the sidebar: settings widgets (pdf_dir, device, n_context_docs), the
# PDF upload form, the per-file document list with delete buttons, and the
# "Reload & Index" / "Clear All" actions. The three settings are module-level
# names read by the rest of the script.
with st.sidebar:
    st.header("βš™οΈ Configuration")
    pdf_dir = st.text_input(
        "πŸ“ PDF Directory",
        value="./pdfs",
        help="Path to directory containing PDF files"
    )
    # A blank field would make os.makedirs("") raise; fall back to the default.
    pdf_dir = pdf_dir.strip() or "./pdfs"
    # Ensure directory exists
    try:
        os.makedirs(pdf_dir, exist_ok=True)
    except OSError as e:
        st.error(f"Cannot use PDF directory {pdf_dir}: {str(e)}")
        st.stop()

    device = st.selectbox(
        "πŸ–₯️ Device",
        ["cpu", "cuda"],
        help="Device for model inference"
    )
    n_context_docs = st.slider(
        "πŸ“„ Context Documents",
        min_value=1,
        max_value=10,
        value=3,
        help="Number of documents to retrieve for context"
    )
    st.divider()

    # PDF Upload Section with Form
    st.subheader("πŸ“€ Upload PDF Files")
    # Use a form to separate file upload from submission
    with st.form("pdf_upload_form", clear_on_submit=True):
        uploaded_pdfs = st.file_uploader(
            "Choose PDF files to upload",
            type="pdf",
            accept_multiple_files=True,
            help="Select one or more PDF files to add to the system"
        )
        submit_button = st.form_submit_button("⬆️ Upload PDFs", use_container_width=True)

    if submit_button and uploaded_pdfs:
        uploaded_count = 0
        for uploaded_file in uploaded_pdfs:
            # Keep only the final path component so a crafted file name
            # cannot write outside pdf_dir.
            safe_name = Path(uploaded_file.name).name
            try:
                if not safe_name:
                    raise ValueError("empty file name")
                file_path = os.path.join(pdf_dir, safe_name)
                # Save file to disk
                with open(file_path, "wb") as f:
                    f.write(uploaded_file.getbuffer())
            except (OSError, ValueError) as e:
                st.error(f"Failed to upload {uploaded_file.name}: {str(e)}")
                continue
            if safe_name not in st.session_state.uploaded_files:
                st.session_state.uploaded_files.append(safe_name)
            uploaded_count += 1
        # Report the files that were saved even if others failed; each failure
        # has already been shown as its own error above.
        if uploaded_count > 0:
            st.session_state.last_upload_time = time.time()
            st.success(f"βœ… Uploaded {uploaded_count} PDF(s) successfully!")
            st.info("πŸ“Œ Click 'Reload & Index' below to process them.")
        # Don't call st.rerun() here - let form handle clear_on_submit
    st.divider()

    # Display uploaded files (sorted so the list order is stable across reruns)
    pdf_files = sorted(Path(pdf_dir).glob("*.pdf"))
    if pdf_files:
        st.subheader(f"πŸ“š Documents ({len(pdf_files)})")
        for pdf_file in pdf_files:
            col1, col2 = st.columns([4, 1])
            with col1:
                st.write(f"β€’ {pdf_file.name}")
            with col2:
                if st.button("πŸ—‘οΈ", key=f"delete_{pdf_file.name}", help="Delete this file"):
                    try:
                        os.remove(pdf_file)
                    except OSError as e:
                        st.error(f"Failed to delete: {str(e)}")
                    else:
                        st.session_state.rag_pipeline = None  # Clear pipeline
                        st.success(f"Deleted {pdf_file.name}")
                        # Brief pause so the message is visible before the rerun.
                        time.sleep(0.5)
                        st.rerun()
    else:
        st.info("πŸ“­ No PDF files in directory yet")
    st.divider()

    # Reload/Index button
    col1, col2 = st.columns(2)
    with col1:
        if st.button("πŸ”„ Reload & Index", use_container_width=True):
            st.session_state.rag_pipeline = None  # Clear cached pipeline
            st.rerun()
    with col2:
        if st.button("πŸ—‘οΈ Clear All", use_container_width=True):
            # Delete all PDFs, remembering the ones that could not be removed
            # instead of silently ignoring them.
            delete_failures = []
            for pdf_file in Path(pdf_dir).glob("*.pdf"):
                try:
                    os.remove(pdf_file)
                except OSError as e:
                    delete_failures.append(f"{pdf_file.name}: {str(e)}")
            st.session_state.rag_pipeline = None
            st.session_state.uploaded_files = []
            if delete_failures:
                # No immediate rerun here, so the error stays on screen.
                st.error("Could not delete: " + "; ".join(delete_failures))
            else:
                st.success("All PDFs cleared")
                time.sleep(0.5)
                st.rerun()
# Initialize pipeline
def _pdf_fingerprint(pdf_dir):
    """Return a hashable snapshot (name, mtime, size) of the PDFs in pdf_dir."""
    snapshot = []
    for pdf_path in sorted(Path(pdf_dir).glob("*.pdf")):
        try:
            file_info = pdf_path.stat()
        except OSError:
            continue  # file disappeared between glob() and stat()
        snapshot.append((pdf_path.name, file_info.st_mtime_ns, file_info.st_size))
    return tuple(snapshot)


# max_entries=1: only the most recent configuration is kept, so superseded
# pipelines are not retained by the cache.
@st.cache_resource(max_entries=1)
def _build_rag_pipeline(device, pdf_dir, pdf_fingerprint):
    """Build and index a RAGPipeline; cached per (device, pdf_dir, pdf_fingerprint).

    pdf_fingerprint is not used in the body: it exists only so the cache key
    changes whenever the set of PDF files changes. Failures propagate as
    exceptions, which means they are not stored in the cache and a later call
    retries.
    """
    with st.spinner("⏳ Initializing models..."):
        pipeline = RAGPipeline(pdf_dir=pdf_dir, device=device)
    with st.spinner("⏳ Indexing PDFs..."):
        pipeline.index_pdfs()
    return pipeline


def init_rag_pipeline(_device, _pdf_dir):
    """Initialize RAG pipeline (cached).

    Args:
        _device: Device string passed to RAGPipeline (e.g. "cpu" or "cuda").
        _pdf_dir: Directory holding the PDF files; created if missing.

    Returns:
        A ``(pipeline, error)`` tuple: ``(pipeline, None)`` on success, or
        ``(None, message)`` when the directory has no PDFs or construction /
        indexing raised.

    The parameter names keep their leading underscore for compatibility, but
    this wrapper is deliberately NOT the cached function: Streamlit leaves
    underscore-prefixed parameters out of the cache key, which previously made
    every call return the first pipeline ever built regardless of device,
    directory, or file changes.
    """
    os.makedirs(_pdf_dir, exist_ok=True)
    fingerprint = _pdf_fingerprint(_pdf_dir)
    if not fingerprint:
        return None, f"No PDF files found in {_pdf_dir}"
    try:
        return _build_rag_pipeline(_device, _pdf_dir, fingerprint), None
    except Exception as e:
        return None, str(e)


# Keep the ``init_rag_pipeline.clear()`` hook that the cached original exposed.
init_rag_pipeline.clear = _build_rag_pipeline.clear
# Get or initialize pipeline
# Resolves the module-level `pipeline` used by the rest of the page: reuse the
# one stored in session state, or build it via init_rag_pipeline(). The script
# is stopped (st.stop) when there are no PDFs or initialization fails.

# A pipeline built for a different device or directory must not be reused, so
# remember which settings the stored pipeline belongs to and drop it on change.
active_config = (device, pdf_dir)
if st.session_state.get("rag_pipeline_config") != active_config:
    st.session_state.rag_pipeline = None

if st.session_state.rag_pipeline is None:
    if not any(Path(pdf_dir).glob("*.pdf")):
        st.warning("πŸ“­ No PDF files found")
        st.info(
            "\n"
            "**How to get started:**\n"
            "1. πŸ“€ Upload PDF files using the sidebar file uploader\n"
            "2. βœ… Click 'Upload PDFs' to save them\n"
            "3. πŸ”„ Click 'Reload & Index' to process\n"
            "4. ❓ Ask questions in the Q&A tab\n"
        )
        st.stop()

    pipeline, error = init_rag_pipeline(device, pdf_dir)
    # Also stop when no pipeline came back without an error message, rather
    # than storing None and silently rendering an empty page.
    if error or pipeline is None:
        st.error(f"❌ Error: {error or 'pipeline could not be initialized'}")
        st.stop()
    st.session_state.rag_pipeline = pipeline
    st.session_state.rag_pipeline_config = active_config
else:
    pipeline = st.session_state.rag_pipeline
# Main content
def _clear_widget_value(widget_key):
    """Button callback: reset the keyed text widget `widget_key` to an empty string.

    A bare st.rerun() does not reset a keyed widget; its value has to be
    changed through session state from a callback.
    """
    st.session_state[widget_key] = ""


if pipeline:
    # Tabs
    tab1, tab2, tab3 = st.tabs(["❓ Q&A", "πŸ“Š Summary", "πŸ“– Retrieval"])

    # Tab 1: Question Answering
    with tab1:
        st.subheader("Ask Questions about Your Documents")
        question = st.text_area(
            "Your question (in Russian or English):",
            height=100,
            placeholder="What is this document about? What are the main points? Etc.",
            key="qa_question"
        )
        col1, col2 = st.columns(2)
        with col1:
            get_answer_btn = st.button("πŸ” Get Answer", use_container_width=True)
        with col2:
            # The click itself triggers a rerun after the callback has run.
            st.button(
                "πŸ—‘οΈ Clear",
                use_container_width=True,
                on_click=_clear_widget_value,
                args=("qa_question",),
            )

        if get_answer_btn:
            if question.strip():
                result = None
                answer_failed = False
                with st.spinner("⏳ Retrieving documents and generating answer..."):
                    try:
                        result = pipeline.answer_question(question, n_context_docs=n_context_docs)
                    except Exception as e:
                        st.error(f"Error: {str(e)}")
                        answer_failed = True
                if result and result.get("answer"):
                    st.success("βœ“ Answer generated!")
                    st.subheader("πŸ“ Answer")
                    st.write(result["answer"])
                    with st.expander("πŸ“š Sources Used"):
                        # .get keeps this consistent with the other optional
                        # fields read from `result` in this block.
                        for i, source in enumerate(result.get("sources", []), 1):
                            st.write(f"{i}. {source}")
                    col1, col2 = st.columns(2)
                    with col1:
                        st.metric("Documents Used", result.get("context_used", 0))
                    with col2:
                        st.metric("Answer Length", len(result["answer"]))
                elif not answer_failed:
                    # Previously an empty result produced no feedback at all.
                    st.warning("No answer could be generated for this question")
            else:
                st.warning("Please enter a question")

    # Tab 2: Document Summary
    with tab2:
        st.subheader("Summary of Indexed Documents")
        if st.button("πŸ“Š Generate Summary", use_container_width=True):
            with st.spinner("⏳ Generating summary..."):
                try:
                    summary = pipeline.summarize_documents()
                except Exception as e:
                    st.error(f"Error: {str(e)}")
                else:
                    st.success("βœ“ Summary generated!")
                    st.subheader("πŸ“„ Document Summary")
                    st.write(summary)

    # Tab 3: Document Retrieval
    with tab3:
        st.subheader("Search and Retrieve Documents")
        search_query = st.text_input(
            "Search query:",
            placeholder="Enter search terms...",
            key="retrieval_search"
        )
        col1, col2 = st.columns(2)
        with col1:
            search_btn = st.button("πŸ”Ž Search", use_container_width=True)
        with col2:
            st.button(
                "Clear Search",
                use_container_width=True,
                on_click=_clear_widget_value,
                args=("retrieval_search",),
            )

        if search_btn:
            if search_query.strip():
                # None means "search failed", as opposed to an empty result, so
                # an error is not followed by a misleading "no documents" note.
                results = None
                with st.spinner("⏳ Searching..."):
                    try:
                        results = pipeline.retrieve_documents(search_query, n_results=n_context_docs)
                    except Exception as e:
                        st.error(f"Search error: {str(e)}")
                if results:
                    st.success(f"βœ“ Found {len(results)} documents")
                    for i, doc in enumerate(results, 1):
                        with st.expander(f"πŸ“„ Document {i} - {doc['source']}", expanded=(i == 1)):
                            st.write(doc["content"])
                elif results is not None:
                    st.warning("No documents found matching your query")
            else:
                st.warning("Please enter a search query")
# Footer
# Collapsible panel with four metrics: indexed chunk count, selected device,
# context-document setting, and number of PDFs on disk.
# NOTE(review): the original nesting was lost with the indentation; this is
# guarded by `if pipeline:` like the tabs above - confirm against the original.
if pipeline:
    st.divider()
    with st.expander("ℹ️ System Information"):
        collection_stats = pipeline.vector_store.get_collection_info()
        pdf_total = len(list(Path(pdf_dir).glob("*.pdf")))
        metric_rows = [
            ("πŸ“š Chunks", collection_stats.get("document_count", 0)),
            ("πŸ–₯️ Device", device.upper()),
            ("πŸ” Context", n_context_docs),
            ("πŸ“ PDFs", pdf_total),
        ]
        # One column per metric, in the order listed above.
        for column, (label, value) in zip(st.columns(4), metric_rows):
            with column:
                st.metric(label, value)