# Author: dnj0
# Last commit: 893bbbd "Update src/app.py" (verified)
# app_with_upload_simple.py
import streamlit as st
import logging
import os
from pathlib import Path
from datetime import datetime
import base64
# Quiet the "pdfminer" logger (used by the PDF extraction stack) so only
# errors reach the console; this runs before the project modules are imported.
logging.getLogger(name="pdfminer").setLevel(level=logging.ERROR)
from pdf_processor import PDFProcessor, prepare_documents_for_embedding
from embeddings_handler import CLIPLangChainEmbeddings
from vectorstore_manager import VectorStoreManager
from rag_chain import RAGChain
from langchain_core.documents import Document
# ============================================================================
# PAGE CONFIGURATION
# ============================================================================
# Browser-tab metadata and layout for the whole app.
st.set_page_config(
    layout="wide",
    page_title="Multimodal RAG Assistant",
    initial_sidebar_state="expanded",
    page_icon="πŸ“„",
)

# Small CSS overrides, injected as raw HTML (hence unsafe_allow_html=True).
_CUSTOM_CSS = (
    "\n"
    "<style>\n"
    ".main { padding: 2rem; }\n"
    '.stTabs [data-baseweb="tab-list"] { gap: 2rem; }\n'
    ".metric-card { background-color: #f8f9fa; padding: 15px; border-radius: 5px; }\n"
    "</style>\n"
)
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# ============================================================================
# SESSION STATE INITIALIZATION
# ============================================================================
# Seed st.session_state with defaults. Only keys that are missing are filled,
# so values stored by earlier reruns of this script are left untouched.
for _state_key, _initial_value in (
    ("processor", None),
    ("vector_store", None),
    ("rag_chain", None),
    ("embeddings", None),
    ("documents_processed", 0),
    ("extracted_content", []),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
del _state_key, _initial_value
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
@st.cache_resource
def init_processor(pdf_dir="./pdfs"):
    """Build the PDFProcessor that reads PDFs from ``pdf_dir``.

    Wrapped in ``st.cache_resource``, so one instance is created per distinct
    ``pdf_dir`` and reused on later calls.
    """
    processor = PDFProcessor(pdf_dir=pdf_dir)
    return processor
@st.cache_resource
def init_embeddings():
    """Create the CLIP embedding wrapper (ViT-B-32, "openai" weights), cached once."""
    clip_settings = {"model_name": "ViT-B-32", "pretrained": "openai"}
    return CLIPLangChainEmbeddings(**clip_settings)
@st.cache_resource
def init_vector_store(embeddings):
    """Open the persistent "pdf_documents" collection under ./chroma_db.

    Args:
        embeddings: Embedding model handed to VectorStoreManager.

    NOTE(review): st.cache_resource hashes ``embeddings`` to build its cache
    key; for a large model wrapper that may be slow or unhashable. Streamlit
    skips hashing parameters whose names start with "_" — consider that if it
    becomes a problem (callers here pass it positionally).
    """
    store = VectorStoreManager(
        embeddings=embeddings,
        collection_name="pdf_documents",
        persist_dir="./chroma_db",
    )
    return store
def save_uploaded_files(uploaded_files, target_dir="./pdfs"):
    """Write uploaded files into ``target_dir`` and return the saved names.

    Args:
        uploaded_files: Iterable of upload objects exposing ``.name`` and
            ``.getbuffer()`` (as returned by ``st.file_uploader``).
        target_dir: Destination directory; created if it does not exist.

    Returns:
        List of the file names actually written (base names only), in input
        order. Uploads whose name has no usable final component ("", ".",
        "..") are skipped.
    """
    os.makedirs(target_dir, exist_ok=True)
    saved_files = []
    for uploaded_file in uploaded_files:
        # The name is supplied by the client: keep only its last path
        # component so "../x.pdf" or "/abs/x.pdf" cannot escape target_dir.
        filename = os.path.basename(uploaded_file.name)
        if filename in ("", ".", ".."):
            continue
        filepath = os.path.join(target_dir, filename)
        with open(filepath, "wb") as f:
            f.write(uploaded_file.getbuffer())
        saved_files.append(filename)
    return saved_files
def get_document_stats(content):
    """Summarise one extracted document.

    Args:
        content: Dict for a single processed document. Only its optional
            ``"pages"`` list is read; each page is a dict that may carry
            ``"text"`` (string), ``"tables"`` and ``"images"`` (sequences).
            Missing keys and ``None`` values are treated as empty.

    Returns:
        Dict with ``pages`` (page count), ``total_text`` (characters of page
        text), ``tables`` and ``images`` (item counts summed over pages).
    """
    # Fetch the page list once; `or` also maps an explicit None to empty.
    pages = content.get("pages") or []
    return {
        "pages": len(pages),
        "total_text": sum(len(page.get("text") or "") for page in pages),
        "tables": sum(len(page.get("tables") or []) for page in pages),
        "images": sum(len(page.get("images") or []) for page in pages),
    }
# ============================================================================
# MAIN APP
# ============================================================================
# Page heading plus a one-line summary of what the app does.
st.title(body="πŸ“„ Multimodal PDF RAG Assistant")
st.markdown(
    body="Upload PDFs, extract content, and query with multimodal embeddings."
)
# ============================================================================
# SIDEBAR - CONFIGURATION & UPLOAD
# ============================================================================
def _save_and_index_pdfs(uploaded_pdfs, api_key):
    """Save uploads, extract content, index chunks and build the RAG chain.

    Progress is reported in the current Streamlit container; results are
    stored in st.session_state (processor, extracted_content,
    documents_processed, embeddings, vector_store, rag_chain).
    """
    if not api_key:
        st.warning(
            "No OpenAI API key provided: answering questions will likely fail "
            "until you add one and process the PDFs again."
        )

    with st.spinner("πŸ“₯ Saving files..."):
        saved_files = save_uploaded_files(uploaded_pdfs)
    st.success(f"βœ… Saved {len(saved_files)} file(s)")

    with st.spinner("πŸ”„ Initializing processor..."):
        processor = init_processor()
        st.session_state.processor = processor

    with st.spinner("πŸ“– Processing PDFs..."):
        # NOTE(review): process_all_pdfs() presumably re-reads every PDF in the
        # processor's directory (not only this upload), and the vector store is
        # persisted to disk, so processing again may index the same chunks
        # twice. Confirm and deduplicate if needed.
        documents = processor.process_all_pdfs()
        st.session_state.extracted_content = documents
        st.session_state.documents_processed = len(documents)
        # Flatten per-document (text, metadata) chunks into one list.
        all_chunks = [
            chunk
            for doc_content in documents
            for chunk in prepare_documents_for_embedding(doc_content)
        ]
    st.success(f"βœ… Processed {len(documents)} PDF(s), {len(all_chunks)} chunks")

    with st.spinner("πŸ”— Creating vector store..."):
        embeddings = init_embeddings()
        st.session_state.embeddings = embeddings
        vector_store = init_vector_store(embeddings)
        st.session_state.vector_store = vector_store
        docs_for_store = [
            Document(page_content=text, metadata=meta)
            for text, meta in all_chunks
        ]
        # An empty batch may be rejected by the underlying store, so only add
        # when there is something to index.
        if docs_for_store:
            vector_store.add_documents(docs_for_store)
        retriever = vector_store.get_retriever()
        st.session_state.rag_chain = RAGChain(retriever, api_key=api_key)
    if not docs_for_store:
        st.warning("No text chunks were extracted, so nothing new was indexed.")
    st.success("βœ… Ready to query!")


with st.sidebar:
    st.header("βš™οΈ Configuration & Upload")

    # API key: pre-filled from the environment and, when given, exported back
    # to OPENAI_API_KEY.
    # NOTE(review): os.environ is process-wide. On a shared Streamlit server a
    # key typed by one user is written here and then pre-filled (via getenv)
    # for every later session. Consider keeping it in st.session_state only;
    # the key is already passed to RAGChain explicitly.
    api_key = st.text_input(
        "OpenAI API Key",
        type="password",
        value=os.getenv("OPENAI_API_KEY", ""),
        help="Your OpenAI API key",
    )
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key
    st.markdown("---")

    # PDF upload and processing.
    st.markdown("### πŸ“€ Upload PDFs")
    uploaded_pdfs = st.file_uploader(
        "Choose PDF files",
        type="pdf",
        accept_multiple_files=True,
        key="pdf_uploader",
        help="Upload one or more PDF files",
    )
    if uploaded_pdfs:
        st.info(f"πŸ“¦ {len(uploaded_pdfs)} file(s) selected")
        if st.button("πŸ’Ύ Save & Process PDFs", use_container_width=True):
            try:
                _save_and_index_pdfs(uploaded_pdfs, api_key)
            except Exception as exc:  # UI boundary: report instead of crashing the page
                logging.getLogger(__name__).exception("PDF processing failed")
                st.error(f"Processing failed: {exc}")
    st.markdown("---")

    # Status summary of what has been processed so far.
    st.markdown("### πŸ“Š Status")
    if st.session_state.documents_processed > 0:
        st.metric("Documents Processed", st.session_state.documents_processed)
        # Reuse the per-document helper so totals match the Documents tab.
        doc_stats = [
            get_document_stats(doc) for doc in st.session_state.extracted_content
        ]
        st.metric("Total Pages", sum(stats["pages"] for stats in doc_stats))
        st.metric("Total Images", sum(stats["images"] for stats in doc_stats))
    else:
        st.info("Upload and process PDFs to get started")
# ============================================================================
# MAIN CONTENT AREA - TABS
# ============================================================================
def _preview(text, limit):
    """Return ``text`` cut to ``limit`` characters, adding "..." only if it was cut."""
    text = text or ""
    if len(text) <= limit:
        return text
    return text[:limit] + "..."


def _render_query_tab():
    """Query tab: send a question to the RAG chain and show answer and sources."""
    st.header("πŸ” Ask Questions")
    st.markdown("Ask questions about your PDF documents.")
    if st.session_state.rag_chain is None:
        st.warning("⚠️ Please process PDFs first using the sidebar.")
        return

    col1, col2 = st.columns([5, 1])
    with col1:
        user_query = st.text_input(
            "Your question:",
            placeholder="What is this document about?",
            label_visibility="collapsed",
        )
    with col2:
        search_button = st.button("πŸ” Search", use_container_width=True)
    if not search_button:
        return
    if not user_query or not user_query.strip():
        st.warning("Please enter a question first.")
        return

    with st.spinner("πŸ€– Searching and generating response..."):
        try:
            result = st.session_state.rag_chain.query(user_query)
            st.markdown("### πŸ“ Answer")
            st.markdown(result["answer"])
            # Sources/metadata may be absent; treat that as "nothing to list".
            sources = result.get("sources") or []
            if sources:
                st.markdown("### πŸ“š Sources")
                for i, source in enumerate(sources, 1):
                    metadata = source.get("metadata") or {}
                    with st.expander(f"Source {i} - {metadata.get('filename', 'Unknown')}"):
                        st.markdown(f"**Type:** {metadata.get('type', 'Unknown')}")
                        st.markdown(f"**Page:** {metadata.get('page', 'Unknown')}")
                        st.markdown(f"**Content:** {_preview(source.get('content'), 500)}")
        except Exception as e:  # UI boundary: show the failure instead of a traceback
            logging.getLogger(__name__).exception("RAG query failed")
            st.error(f"❌ Error: {str(e)}")


def _render_documents_tab():
    """Documents tab: corpus-wide totals plus per-document stats and previews."""
    st.header("πŸ“Š Processed Documents")
    documents = st.session_state.extracted_content
    if not documents:
        st.info("No documents processed yet.")
        return

    # Compute stats once per document; reused for totals and details.
    doc_stats = [get_document_stats(doc) for doc in documents]
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Documents", len(documents))
    with col2:
        st.metric("Pages", sum(stats["pages"] for stats in doc_stats))
    with col3:
        st.metric("Images", sum(stats["images"] for stats in doc_stats))
    with col4:
        st.metric("Tables", sum(stats["tables"] for stats in doc_stats))
    st.markdown("---")

    st.markdown("### πŸ“„ Document Details")
    for idx, (doc, stats) in enumerate(zip(documents, doc_stats), 1):
        filename = doc.get("filename", f"Document {idx}")
        with st.expander(f"πŸ“‘ {filename}"):
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Pages", stats["pages"])
            with col2:
                st.metric("Images", stats["images"])
            with col3:
                st.metric("Tables", stats["tables"])
            with col4:
                st.metric("Text (KB)", round(stats["total_text"] / 1024, 1))
            st.markdown("#### First 3 Pages Preview:")
            for page in (doc.get("pages") or [])[:3]:
                page_num = page.get("page_number")
                st.write(f"**Page {page_num}:** {_preview(page.get('text'), 200)}")


def _render_images_tab():
    """Images tab: display every extracted image that carries base64 data."""
    st.header("πŸ–ΌοΈ Extracted Images")
    if not st.session_state.extracted_content:
        st.info("No images extracted yet.")
        return

    image_count = 0
    for doc_idx, doc in enumerate(st.session_state.extracted_content, 1):
        filename = doc.get("filename", f"Document {doc_idx}")
        for page in doc.get("pages", []):
            page_num = page.get("page_number")
            images = page.get("images", [])
            if not images:
                continue
            st.markdown(f"### πŸ“„ {filename} - Page {page_num}")
            # Up to two columns; images alternate between them.
            img_cols = st.columns(min(len(images), 2))
            for idx, image in enumerate(images):
                with img_cols[idx % 2]:
                    if not image.get("base64"):
                        st.warning("No image data available")
                        continue
                    try:
                        # NOTE(review): use_column_width is deprecated in newer
                        # Streamlit releases; switch once the pinned version is known.
                        st.image(
                            f"data:image/{image.get('format', 'png')};base64,{image.get('base64')}",
                            caption=f"Image {image.get('index')}",
                            use_column_width=True,
                        )
                        image_count += 1
                    except Exception as e:
                        st.warning(f"Could not display image: {e}")
    if image_count == 0:
        st.info("No images were successfully extracted from the PDFs.")


def _render_info_tab():
    """Info tab: static feature/stack description plus live component status."""
    st.header("ℹ️ System Information")
    st.markdown("### 🎯 Features")
    features = {
        "βœ… PDF Upload": "Upload multiple PDFs via UI",
        "βœ… Text Extraction": "Extract text from documents",
        "βœ… Table Detection": "Identify and extract tables",
        "βœ… Image Extraction": "Extract and display images",
        "βœ… CLIP Embeddings": "Multimodal embeddings",
        "βœ… Vector Store": "ChromaDB for similarity search",
        "βœ… RAG Chain": "LangChain with OpenAI",
        "βœ… Russian Support": "Queries answered in Russian",
    }
    for feature, description in features.items():
        st.markdown(f"**{feature}** - {description}")
    st.markdown("---")

    st.markdown("### πŸ“¦ System Status")
    col1, col2, col3 = st.columns(3)
    with col1:
        if st.session_state.processor:
            st.success("βœ… Processor Ready")
        else:
            st.warning("⚠️ Processor Not Initialized")
    with col2:
        if st.session_state.embeddings:
            st.success("βœ… Embeddings Ready")
        else:
            st.warning("⚠️ Embeddings Not Initialized")
    with col3:
        if st.session_state.rag_chain:
            st.success("βœ… RAG Chain Ready")
        else:
            st.warning("⚠️ RAG Chain Not Initialized")
    st.markdown("---")

    st.markdown("### πŸš€ How It Works")
    st.markdown("""
1. **Upload**: Select one or more PDF files
2. **Process**: System extracts text, tables, and images
3. **Embed**: Content converted to multimodal embeddings
4. **Store**: Vectors stored in ChromaDB
5. **Query**: Ask questions about documents
6. **Retrieve**: Relevant content fetched from store
7. **Generate**: OpenAI creates response
8. **Display**: Answer and sources shown in UI
""")
    st.markdown("---")

    st.markdown("### πŸ”— Technology Stack")
    tech_info = {
        "PDF Processing": "PyMuPDF, pdfplumber",
        "Embeddings": "CLIP ViT-B-32 (open-clip-torch)",
        "Vector Store": "ChromaDB",
        "LLM Framework": "LangChain",
        "Language Model": "OpenAI GPT-4o-mini",
        "Web UI": "Streamlit",
    }
    for tech, details in tech_info.items():
        st.write(f"**{tech}:** {details}")


# Tabs are shown only once at least one document has been processed.
if st.session_state.documents_processed == 0:
    st.warning("πŸ‘ˆ Upload PDFs in the sidebar to get started")
else:
    tab1, tab2, tab3, tab4 = st.tabs(["πŸ” Query", "πŸ“Š Documents", "πŸ–ΌοΈ Images", "ℹ️ Info"])
    with tab1:
        _render_query_tab()
    with tab2:
        _render_documents_tab()
    with tab3:
        _render_images_tab()
    with tab4:
        _render_info_tab()
# ============================================================================
# FOOTER
# ============================================================================
# Footer: a divider followed by a small, centred, grey credit line in raw HTML.
st.markdown("---")
_FOOTER_HTML = "".join((
    "<div style='text-align: center; color: gray; font-size: 0.8rem;'>",
    "Multimodal RAG LLM System | Powered by LangChain, ChromaDB, CLIP, and OpenAI",
    "</div>",
))
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)