Spaces:
Sleeping
Sleeping
| # app_with_upload_simple.py | |
| import streamlit as st | |
| import logging | |
| import os | |
| from pathlib import Path | |
| from datetime import datetime | |
| import base64 | |
# Logging setup: raise the "pdfminer" logger's threshold to ERROR so only
# errors from that library (used during PDF parsing) reach the log output.
logging.getLogger(name="pdfminer").setLevel(level=logging.ERROR)
| from pdf_processor import PDFProcessor, prepare_documents_for_embedding | |
| from embeddings_handler import CLIPLangChainEmbeddings | |
| from vectorstore_manager import VectorStoreManager | |
| from rag_chain import RAGChain | |
| from langchain_core.documents import Document | |
# ============================================================================
# PAGE CONFIGURATION
# ============================================================================
# Wide layout with the sidebar open on first load; issued before any other
# st.* call in this script.
# NOTE(review): the emoji in this file's UI strings (e.g. page_icon "π") look
# like UTF-8 mis-decoded as ISO-8859-7; restore them from the original source.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="Multimodal RAG Assistant",
    page_icon="π",
)
# Custom CSS injected as raw HTML: main-area padding, wider gaps between tab
# headers, and a `.metric-card` style (not referenced by the visible code).
_CUSTOM_CSS = """
<style>
.main { padding: 2rem; }
.stTabs [data-baseweb="tab-list"] { gap: 2rem; }
.metric-card { background-color: #f8f9fa; padding: 15px; border-radius: 5px; }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# ============================================================================
# SESSION STATE INITIALIZATION
# ============================================================================
# Per-session defaults. A key is only seeded when absent, so values set by an
# earlier rerun of the script survive. The table is rebuilt on every rerun,
# so the [] default is a fresh list each time.
_SESSION_DEFAULTS = {
    "processor": None,
    "vector_store": None,
    "rag_chain": None,
    "embeddings": None,
    "documents_processed": 0,
    "extracted_content": [],
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
def init_processor(pdf_dir="./pdfs"):
    """Create a ``PDFProcessor`` configured with *pdf_dir*.

    The default matches the directory that ``save_uploaded_files`` writes to.
    """
    processor = PDFProcessor(pdf_dir=pdf_dir)
    return processor
@st.cache_resource(show_spinner=False)
def init_embeddings(model_name="ViT-B-32", pretrained="openai"):
    """Return CLIP embeddings, constructing them at most once per argument set.

    ``st.cache_resource`` keeps the returned object alive across reruns and
    sessions. Clicking "Save & Process PDFs" repeatedly therefore no longer
    re-instantiates ``CLIPLangChainEmbeddings``, which presumably loads model
    weights each time. Caller-visible behaviour is otherwise unchanged: called
    with no arguments it builds the same ViT-B-32 / "openai" configuration as
    before.

    Args:
        model_name: CLIP architecture name forwarded to ``CLIPLangChainEmbeddings``.
        pretrained: pretrained-weights tag forwarded to ``CLIPLangChainEmbeddings``.

    Returns:
        The (shared, cached) ``CLIPLangChainEmbeddings`` instance.
    """
    # show_spinner=False: the caller already wraps this call in st.spinner.
    # NOTE(review): the instance is shared by all sessions; confirm
    # CLIPLangChainEmbeddings keeps no per-request mutable state.
    return CLIPLangChainEmbeddings(model_name=model_name, pretrained=pretrained)
def init_vector_store(embeddings):
    """Build the app's ``VectorStoreManager``.

    Uses the "pdf_documents" collection persisted under ``./chroma_db`` and
    embeds content with *embeddings*.
    """
    store_settings = {
        "persist_dir": "./chroma_db",
        "collection_name": "pdf_documents",
    }
    return VectorStoreManager(embeddings=embeddings, **store_settings)
def save_uploaded_files(uploaded_files, target_dir="./pdfs"):
    """Write uploaded files into *target_dir* and return the names used.

    Only the final path component of each client-supplied filename is kept.
    A crafted name such as ``../../x.pdf`` or ``..\\x.pdf`` therefore cannot
    write outside *target_dir*. Names that reduce to nothing usable ("", ".",
    "..") are skipped. A file that already exists under the same name is
    overwritten.

    Args:
        uploaded_files: iterable of upload objects exposing ``.name`` (str) and
            ``.getbuffer()`` (bytes-like), e.g. Streamlit ``UploadedFile``.
        target_dir: directory to write into; created if missing.

    Returns:
        List of the sanitized file names actually written, in input order.
    """
    os.makedirs(target_dir, exist_ok=True)
    saved_files = []
    for uploaded_file in uploaded_files:
        # Normalise Windows separators so basename() also strips them on POSIX.
        safe_name = os.path.basename(uploaded_file.name.replace("\\", "/"))
        if safe_name in ("", ".", ".."):
            continue
        filepath = os.path.join(target_dir, safe_name)
        with open(filepath, "wb") as f:
            f.write(uploaded_file.getbuffer())
        saved_files.append(safe_name)
    return saved_files
def get_document_stats(content):
    """Summarise one extracted document.

    Args:
        content: mapping for one document. Its optional ``"pages"`` entry is a
            sequence of page mappings with optional ``"text"``, ``"tables"``
            and ``"images"`` entries.

    Returns:
        dict with ``pages`` (number of pages), ``total_text`` (summed ``len``
        of page texts), ``tables`` and ``images`` (summed list lengths).
    """
    pages = content.get("pages", [])
    page_count = len(pages)
    text_length = 0
    table_count = 0
    image_count = 0
    for page in pages:
        text_length += len(page.get("text", ""))
        table_count += len(page.get("tables", []))
        image_count += len(page.get("images", []))
    return {
        "pages": page_count,
        "total_text": text_length,
        "tables": table_count,
        "images": image_count,
    }
# ============================================================================
# MAIN APP
# ============================================================================
_APP_TITLE = "π Multimodal PDF RAG Assistant"
_APP_TAGLINE = "Upload PDFs, extract content, and query with multimodal embeddings."
st.title(_APP_TITLE)
st.markdown(_APP_TAGLINE)
# ============================================================================
# SIDEBAR - CONFIGURATION & UPLOAD
# ============================================================================
with st.sidebar:
    st.header("βοΈ Configuration & Upload")

    # API Key.
    # Security: the field is NOT pre-filled from os.environ, and the entered key
    # is NOT written back into os.environ.
    # - Streamlit sends a widget's default value to the browser, so pre-filling
    #   would hand a server-side OPENAI_API_KEY to every visitor.
    # - os.environ is shared by every session this process serves, so storing
    #   one visitor's key there would leak it to the others.
    # A server-side OPENAI_API_KEY is still used as a fallback when the field is
    # left empty.
    # NOTE(review): confirm RAGChain uses its api_key argument rather than
    # reading the environment.
    api_key = st.text_input(
        "OpenAI API Key",
        type="password",
        help="Your OpenAI API key",
    ) or os.getenv("OPENAI_API_KEY", "")

    st.markdown("---")

    # PDF Upload Section
    st.markdown("### π€ Upload PDFs")
    uploaded_pdfs = st.file_uploader(
        "Choose PDF files",
        type="pdf",
        accept_multiple_files=True,
        key="pdf_uploader",
        help="Upload one or more PDF files",
    )

    if uploaded_pdfs:
        st.info(f"π¦ {len(uploaded_pdfs)} file(s) selected")

        if st.button("πΎ Save & Process PDFs", use_container_width=True):
            if not api_key:
                # Fail fast: the final RAG chain needs a key, so don't run the
                # slow extraction and embedding steps without one.
                st.error("Please enter your OpenAI API key before processing PDFs.")
            else:
                # Drop any chain from an earlier run, so a failure below cannot
                # leave a chain querying stale data beside freshly extracted content.
                st.session_state.rag_chain = None
                try:
                    # Save files
                    with st.spinner("π₯ Saving files..."):
                        saved_files = save_uploaded_files(uploaded_pdfs)
                    st.success(f"β Saved {len(saved_files)} file(s)")

                    # Initialize processor
                    with st.spinner("π Initializing processor..."):
                        processor = init_processor()
                        st.session_state.processor = processor

                    # Process PDFs.
                    # NOTE(review): process_all_pdfs() appears to read the whole
                    # ./pdfs directory, so files saved by earlier runs are
                    # re-processed as well. Confirm this is intended.
                    with st.spinner("π Processing PDFs..."):
                        documents = processor.process_all_pdfs()
                        st.session_state.extracted_content = documents
                        st.session_state.documents_processed = len(documents)

                        # Flatten every document's chunks into one list.
                        all_chunks = []
                        for doc_content in documents:
                            all_chunks.extend(prepare_documents_for_embedding(doc_content))
                    st.success(f"β Processed {len(documents)} PDF(s), {len(all_chunks)} chunks")

                    if not all_chunks:
                        # Nothing to index; building a chain over an empty store
                        # would only produce empty answers.
                        st.warning("No text chunks were extracted, so nothing was indexed.")
                    else:
                        # Initialize embeddings and vector store.
                        # NOTE(review): the store persists in ./chroma_db, so
                        # re-processing the same PDFs may add duplicate chunks.
                        with st.spinner("π Creating vector store..."):
                            embeddings = init_embeddings()
                            st.session_state.embeddings = embeddings
                            vector_store = init_vector_store(embeddings)
                            st.session_state.vector_store = vector_store

                            # Each chunk is unpacked as a (text, metadata) pair.
                            docs_for_store = [
                                Document(page_content=text, metadata=meta)
                                for text, meta in all_chunks
                            ]
                            vector_store.add_documents(docs_for_store)

                            # Initialize RAG chain
                            retriever = vector_store.get_retriever()
                            st.session_state.rag_chain = RAGChain(retriever, api_key=api_key)
                        st.success("β Ready to query!")
                except Exception as e:
                    # UI boundary: show the failure instead of a raw traceback,
                    # and keep the full traceback in the server log.
                    logging.getLogger(__name__).exception("PDF processing failed")
                    st.error(f"β Error: {str(e)}")

    st.markdown("---")

    # Status
    st.markdown("### π Status")
    if st.session_state.documents_processed > 0:
        # Reuse the shared per-document summary instead of re-deriving counts.
        doc_stats = [get_document_stats(doc) for doc in st.session_state.extracted_content]
        st.metric("Documents Processed", st.session_state.documents_processed)
        st.metric("Total Pages", sum(stats["pages"] for stats in doc_stats))
        st.metric("Total Images", sum(stats["images"] for stats in doc_stats))
    else:
        st.info("Upload and process PDFs to get started")
# ============================================================================
# MAIN CONTENT AREA - TABS
# ============================================================================
def _truncate(text, limit):
    """Return *text* cut to *limit* characters, appending "..." only if it was cut."""
    if len(text) <= limit:
        return text
    return text[:limit] + "..."


def _render_query_tab():
    """Render the question box and, when Search is pressed, the RAG answer and sources."""
    st.header("π Ask Questions")
    st.markdown("Ask questions about your PDF documents.")

    if st.session_state.rag_chain is None:
        st.warning("β οΈ Please process PDFs first using the sidebar.")
        return

    col1, col2 = st.columns([5, 1])
    with col1:
        user_query = st.text_input(
            "Your question:",
            placeholder="What is this document about?",
            label_visibility="collapsed",
        )
    with col2:
        search_button = st.button("π Search", use_container_width=True)

    if not search_button:
        return
    if not user_query.strip():
        # Previously an empty click did nothing at all; tell the user why.
        st.warning("Please enter a question first.")
        return

    with st.spinner("π€ Searching and generating response..."):
        try:
            result = st.session_state.rag_chain.query(user_query)

            # Display answer
            st.markdown("### π Answer")
            st.markdown(result["answer"])

            # Display sources. The result is treated as a dict with optional
            # "sources", each carrying "metadata" and "content".
            sources = result.get("sources") or []
            if sources:
                st.markdown("### π Sources")
                for i, source in enumerate(sources, 1):
                    metadata = source.get("metadata") or {}
                    with st.expander(f"Source {i} - {metadata.get('filename', 'Unknown')}"):
                        st.markdown(f"**Type:** {metadata.get('type', 'Unknown')}")
                        st.markdown(f"**Page:** {metadata.get('page', 'Unknown')}")
                        st.markdown(f"**Content:** {_truncate(source.get('content') or '', 500)}")
        except Exception as e:
            # UI boundary: surface the error in the page and log the traceback.
            logging.getLogger(__name__).exception("RAG query failed")
            st.error(f"β Error: {str(e)}")


def _render_documents_tab():
    """Render aggregate counts and a per-document breakdown with page previews."""
    st.header("π Processed Documents")

    docs = st.session_state.extracted_content
    if not docs:
        st.info("No documents processed yet.")
        return

    # One summary per document, shared by the totals and the detail view.
    all_stats = [get_document_stats(doc) for doc in docs]

    # Overall statistics
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Documents", len(docs))
    with col2:
        st.metric("Pages", sum(stats["pages"] for stats in all_stats))
    with col3:
        st.metric("Images", sum(stats["images"] for stats in all_stats))
    with col4:
        st.metric("Tables", sum(stats["tables"] for stats in all_stats))

    st.markdown("---")

    # Document details
    st.markdown("### π Document Details")
    for idx, (doc, stats) in enumerate(zip(docs, all_stats), 1):
        filename = doc.get("filename", f"Document {idx}")
        with st.expander(f"π {filename}"):
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Pages", stats["pages"])
            with col2:
                st.metric("Images", stats["images"])
            with col3:
                st.metric("Tables", stats["tables"])
            with col4:
                # total_text is a character count; shown divided by 1024.
                st.metric("Text (KB)", round(stats["total_text"] / 1024, 1))

            # Preview pages
            st.markdown("#### First 3 Pages Preview:")
            for page in doc.get("pages", [])[:3]:
                page_num = page.get("page_number")
                preview = _truncate(page.get("text") or "", 200)
                st.write(f"**Page {page_num}:** {preview}")


def _render_images_tab():
    """Render every extracted image, two per row, grouped by document and page."""
    st.header("πΌοΈ Extracted Images")

    docs = st.session_state.extracted_content
    if not docs:
        st.info("No images extracted yet.")
        return

    image_count = 0
    for doc_idx, doc in enumerate(docs, 1):
        filename = doc.get("filename", f"Document {doc_idx}")
        for page in doc.get("pages", []):
            images = page.get("images", [])
            if not images:
                continue
            st.markdown(f"### π {filename} - Page {page.get('page_number')}")
            img_cols = st.columns(min(len(images), 2))
            for idx, image in enumerate(images):
                with img_cols[idx % len(img_cols)]:
                    encoded = image.get("base64")
                    if not encoded:
                        st.warning("No image data available")
                        continue
                    try:
                        # Decode to raw bytes rather than passing a data: URI.
                        # Undecodable or unsupported image data then raises
                        # here and is reported, instead of showing as a broken
                        # image that is still counted.
                        st.image(
                            base64.b64decode(encoded),
                            caption=f"Image {image.get('index')}",
                            use_column_width=True,
                        )
                        image_count += 1
                    except Exception as e:
                        st.warning(f"Could not display image: {e}")

    if image_count == 0:
        st.info("No images were successfully extracted from the PDFs.")


def _render_info_tab():
    """Render the static feature list, component status, pipeline steps and stack."""
    st.header("βΉοΈ System Information")

    st.markdown("### π― Features")
    features = {
        "β PDF Upload": "Upload multiple PDFs via UI",
        "β Text Extraction": "Extract text from documents",
        "β Table Detection": "Identify and extract tables",
        "β Image Extraction": "Extract and display images",
        "β CLIP Embeddings": "Multimodal embeddings",
        "β Vector Store": "ChromaDB for similarity search",
        "β RAG Chain": "LangChain with OpenAI",
        "β Russian Support": "Queries answered in Russian",
    }
    for feature, description in features.items():
        st.markdown(f"**{feature}** - {description}")

    st.markdown("---")

    st.markdown("### π¦ System Status")
    col1, col2, col3 = st.columns(3)
    with col1:
        if st.session_state.processor is not None:
            st.success("β Processor Ready")
        else:
            st.warning("β οΈ Processor Not Initialized")
    with col2:
        if st.session_state.embeddings is not None:
            st.success("β Embeddings Ready")
        else:
            st.warning("β οΈ Embeddings Not Initialized")
    with col3:
        if st.session_state.rag_chain is not None:
            st.success("β RAG Chain Ready")
        else:
            st.warning("β οΈ RAG Chain Not Initialized")

    st.markdown("---")

    st.markdown("### π How It Works")
    st.markdown("""
    1. **Upload**: Select one or more PDF files
    2. **Process**: System extracts text, tables, and images
    3. **Embed**: Content converted to multimodal embeddings
    4. **Store**: Vectors stored in ChromaDB
    5. **Query**: Ask questions about documents
    6. **Retrieve**: Relevant content fetched from store
    7. **Generate**: OpenAI creates response
    8. **Display**: Answer and sources shown in UI
    """)

    st.markdown("---")

    st.markdown("### π Technology Stack")
    tech_info = {
        "PDF Processing": "PyMuPDF, pdfplumber",
        "Embeddings": "CLIP ViT-B-32 (open-clip-torch)",
        "Vector Store": "ChromaDB",
        "LLM Framework": "LangChain",
        "Language Model": "OpenAI GPT-4o-mini",
        "Web UI": "Streamlit",
    }
    for tech, details in tech_info.items():
        st.write(f"**{tech}:** {details}")


if st.session_state.documents_processed == 0:
    st.warning("π Upload PDFs in the sidebar to get started")
else:
    tab1, tab2, tab3, tab4 = st.tabs(["π Query", "π Documents", "πΌοΈ Images", "βΉοΈ Info"])
    with tab1:
        _render_query_tab()
    with tab2:
        _render_documents_tab()
    with tab3:
        _render_images_tab()
    with tab4:
        _render_info_tab()
# ============================================================================
# FOOTER
# ============================================================================
# Small centred grey credit line under a horizontal rule; rendered as raw HTML.
_FOOTER_HTML = (
    "<div style='text-align: center; color: gray; font-size: 0.8rem;'>"
    "Multimodal RAG LLM System | Powered by LangChain, ChromaDB, CLIP, and OpenAI"
    "</div>"
)
st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)