|
|
import streamlit as st |
|
|
import os |
|
|
from pathlib import Path |
|
|
from rag_pipeline import RAGPipeline |
|
|
import time |
|
|
|
|
|
|
|
|
# Page-level Streamlit configuration; must be the first st.* call in the script.
st.set_page_config(
    page_title="Local Multimodal RAG",
    page_icon="π",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Header shown at the top of the main panel.
st.title("π Local Multimodal RAG System")
st.markdown("**Analyze PDF documents locally with Mistral + CLIP embeddings**")
|
|
|
|
|
|
|
|
# Seed st.session_state with per-session defaults on the first script run.
# Item access (st.session_state[k]) is equivalent to attribute access.
for _key, _default in (
    ("uploaded_files", []),       # names of PDFs saved during this session
    ("rag_pipeline", None),       # cached pipeline; None forces (re)initialization
    ("last_upload_time", 0),      # timestamp of the most recent successful upload
):
    if _key not in st.session_state:
        st.session_state[_key] = _default
|
|
|
|
|
|
|
|
# ---------- Sidebar: configuration, PDF upload, and file management ----------
with st.sidebar:
    st.header("βοΈ Configuration")

    # Directory where uploaded PDFs are saved and read back from.
    pdf_dir = st.text_input(
        "π PDF Directory",
        value="./pdfs",
        help="Path to directory containing PDF files",
    )
    # Ensure the directory exists before any glob/save below.
    os.makedirs(pdf_dir, exist_ok=True)

    device = st.selectbox(
        "π₯οΈ Device",
        ["cpu", "cuda"],
        help="Device for model inference",
    )

    n_context_docs = st.slider(
        "π Context Documents",
        min_value=1,
        max_value=10,
        value=3,
        help="Number of documents to retrieve for context",
    )

    st.divider()

    st.subheader("π€ Upload PDF Files")

    # A form batches multiple file selections into one submit and
    # clears the uploader widget afterwards (clear_on_submit=True).
    with st.form("pdf_upload_form", clear_on_submit=True):
        uploaded_pdfs = st.file_uploader(
            "Choose PDF files to upload",
            type="pdf",
            accept_multiple_files=True,
            help="Select one or more PDF files to add to the system",
        )
        submit_button = st.form_submit_button("β¬οΈ Upload PDFs", use_container_width=True)

    if submit_button and uploaded_pdfs:
        upload_successful = True
        uploaded_count = 0

        for uploaded_file in uploaded_pdfs:
            try:
                # Persist the in-memory upload to disk under pdf_dir.
                file_path = os.path.join(pdf_dir, uploaded_file.name)
                with open(file_path, "wb") as f:
                    f.write(uploaded_file.getbuffer())
                st.session_state.uploaded_files.append(uploaded_file.name)
                uploaded_count += 1
            except Exception as e:
                st.error(f"Failed to upload {uploaded_file.name}: {str(e)}")
                upload_successful = False

        if upload_successful and uploaded_count > 0:
            st.session_state.last_upload_time = time.time()
            st.success(f"β Uploaded {uploaded_count} PDF(s) successfully!")
            st.info("π Click 'Reload & Index PDFs' below to process them.")

    st.divider()

    # List the PDFs currently on disk, each with a per-file delete button.
    pdf_files = list(Path(pdf_dir).glob("*.pdf"))
    if pdf_files:
        st.subheader(f"π Documents ({len(pdf_files)})")

        for pdf_file in pdf_files:
            col1, col2 = st.columns([4, 1])
            with col1:
                st.write(f"β’ {pdf_file.name}")
            with col2:
                if st.button("ποΈ", key=f"delete_{pdf_file.name}", help="Delete this file"):
                    try:
                        os.remove(pdf_file)
                        # Corpus changed: drop the pipeline to force re-indexing.
                        st.session_state.rag_pipeline = None
                        st.success(f"Deleted {pdf_file.name}")
                        time.sleep(0.5)
                        st.rerun()
                    except Exception as e:
                        st.error(f"Failed to delete: {str(e)}")
    else:
        st.info("π No PDF files in directory yet")

    st.divider()

    col1, col2 = st.columns(2)
    with col1:
        # Dropping the cached pipeline triggers re-initialization on rerun.
        if st.button("π Reload & Index", use_container_width=True):
            st.session_state.rag_pipeline = None
            st.rerun()

    with col2:
        if st.button("ποΈ Clear All", use_container_width=True):
            for pdf_file in Path(pdf_dir).glob("*.pdf"):
                try:
                    os.remove(pdf_file)
                except OSError as e:
                    # Best-effort bulk delete, but surface filesystem failures
                    # instead of silently swallowing every exception
                    # (was: bare `except: pass`, which also hid KeyboardInterrupt).
                    st.warning(f"Could not delete {pdf_file.name}: {e}")
            st.session_state.rag_pipeline = None
            st.session_state.uploaded_files = []
            st.success("All PDFs cleared")
            time.sleep(0.5)
            st.rerun()
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource
def init_rag_pipeline(_device, _pdf_dir, cache_key=None):
    """Build and index a RAGPipeline, cached across reruns and sessions.

    Args:
        _device: Inference device string ("cpu" or "cuda").
        _pdf_dir: Directory scanned for ``*.pdf`` files to index.
        cache_key: Optional value included in the cache key. NOTE:
            ``st.cache_resource`` excludes underscore-prefixed parameters
            from the cache key, so ``_device`` and ``_pdf_dir`` are NOT
            hashed -- with the default ``cache_key=None`` the first
            pipeline ever built is returned on every subsequent call,
            even after the PDF set changes. Pass a changing value
            (e.g. ``st.session_state.last_upload_time``) to force a
            rebuild; the default preserves the old behavior.

    Returns:
        Tuple ``(pipeline, error)``: a ready, indexed RAGPipeline and
        ``None`` on success, or ``(None, message)`` describing the failure.
    """
    os.makedirs(_pdf_dir, exist_ok=True)

    # Nothing to index: report it rather than constructing an empty pipeline.
    pdf_files = list(Path(_pdf_dir).glob("*.pdf"))
    if not pdf_files:
        return None, f"No PDF files found in {_pdf_dir}"

    try:
        with st.spinner("β³ Initializing models..."):
            pipeline = RAGPipeline(pdf_dir=_pdf_dir, device=_device)

        with st.spinner("β³ Indexing PDFs..."):
            pipeline.index_pdfs()

        return pipeline, None
    except Exception as e:
        # Return the failure to the caller so the UI can show it
        # instead of crashing the whole app.
        return None, str(e)
|
|
|
|
|
|
|
|
|
|
|
# Lazily (re)build the pipeline whenever the session does not hold one;
# otherwise reuse the instance stored in session state.
if st.session_state.rag_pipeline is None:
    pdf_files = list(Path(pdf_dir).glob("*.pdf"))

    # Guard clause: with no PDFs there is nothing to index -- show
    # onboarding help and halt the script for this run.
    if not pdf_files:
        st.warning("π No PDF files found")
        st.info("""
**How to get started:**
1. π€ Upload PDF files using the sidebar file uploader
2. β Click 'Upload PDFs' to save them
3. π Click 'Reload & Index PDFs' to process
4. β Ask questions in the Q&A tab
""")
        st.stop()

    pipeline, error = init_rag_pipeline(device, pdf_dir)
    if error:
        st.error(f"β Error: {error}")
        st.stop()
    st.session_state.rag_pipeline = pipeline
else:
    pipeline = st.session_state.rag_pipeline
|
|
|
|
|
|
|
|
|
|
|
# ---------- Main panel: Q&A, summary, and retrieval tabs ----------
if pipeline:
    tab1, tab2, tab3 = st.tabs(["β Q&A", "π Summary", "π Retrieval"])

    # --- Tab 1: question answering over the indexed documents ---
    with tab1:
        st.subheader("Ask Questions about Your Documents")

        question = st.text_area(
            "Your question (in Russian or English):",
            height=100,
            placeholder="What is this document about? What are the main points? Etc.",
            key="qa_question",
        )

        col1, col2 = st.columns(2)
        with col1:
            get_answer_btn = st.button("π Get Answer", use_container_width=True)
        with col2:
            clear_btn = st.button("ποΈ Clear", use_container_width=True)

        if clear_btn:
            st.rerun()

        if get_answer_btn:
            if question.strip():
                with st.spinner("β³ Retrieving documents and generating answer..."):
                    try:
                        result = pipeline.answer_question(question, n_context_docs=n_context_docs)
                    except Exception as e:
                        st.error(f"Error: {str(e)}")
                        result = None  # downstream code checks for None

                if result and result.get("answer"):
                    st.success("β Answer generated!")

                    st.subheader("π Answer")
                    st.write(result["answer"])

                    with st.expander("π Sources Used"):
                        # .get with a default for consistency with the other
                        # result lookups: a missing "sources" key should not
                        # KeyError and blank the page after a successful answer.
                        for i, source in enumerate(result.get("sources", []), 1):
                            st.write(f"{i}. {source}")

                    col1, col2 = st.columns(2)
                    with col1:
                        st.metric("Documents Used", result.get("context_used", 0))
                    with col2:
                        st.metric("Answer Length", len(result["answer"]))
            else:
                st.warning("Please enter a question")

    # --- Tab 2: one-shot summary of the whole indexed corpus ---
    with tab2:
        st.subheader("Summary of Indexed Documents")

        if st.button("π Generate Summary", use_container_width=True):
            with st.spinner("β³ Generating summary..."):
                try:
                    summary = pipeline.summarize_documents()
                    st.success("β Summary generated!")
                    st.subheader("π Document Summary")
                    st.write(summary)
                except Exception as e:
                    st.error(f"Error: {str(e)}")

    # --- Tab 3: raw retrieval (no generation) for inspecting matches ---
    with tab3:
        st.subheader("Search and Retrieve Documents")

        search_query = st.text_input(
            "Search query:",
            placeholder="Enter search terms...",
            key="retrieval_search",
        )

        col1, col2 = st.columns(2)
        with col1:
            search_btn = st.button("π Search", use_container_width=True)
        with col2:
            clear_search_btn = st.button("Clear Search", use_container_width=True)

        if clear_search_btn:
            st.rerun()

        if search_btn:
            if search_query.strip():
                with st.spinner("β³ Searching..."):
                    try:
                        results = pipeline.retrieve_documents(search_query, n_results=n_context_docs)
                    except Exception as e:
                        st.error(f"Search error: {str(e)}")
                        results = []

                if results:
                    st.success(f"β Found {len(results)} documents")
                    for i, doc in enumerate(results, 1):
                        # Expand only the top-ranked hit by default.
                        with st.expander(f"π Document {i} - {doc['source']}", expanded=(i == 1)):
                            st.write(doc["content"])
                else:
                    st.warning("No documents found matching your query")
            else:
                st.warning("Please enter a search query")

    st.divider()

    # --- Footer: live stats about the index and current settings ---
    with st.expander("βΉοΈ System Information"):
        info = pipeline.vector_store.get_collection_info()
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("π Chunks", info.get("document_count", 0))
        with col2:
            st.metric("π₯οΈ Device", device.upper())
        with col3:
            st.metric("π Context", n_context_docs)
        with col4:
            pdf_count = len(list(Path(pdf_dir).glob("*.pdf")))
            st.metric("π PDFs", pdf_count)