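"""Streamlit front end for a table-aware RAG assistant.

Users upload documents, which are chunked with a table-aware recursive chunker,
processed and embedded, and ingested into a Pinecone index; a chat interface then
answers questions over the retrieved text and table chunks.
"""
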
import streamlit as st
from pathlib import Path
import tempfile
import os
import time
from typing import List, Dict, Iterator
from pinecone import Pinecone
from src.table_aware_chunker import TableRecursiveChunker
from src.processor import TableProcessor
from src.llm import LLMChat
from src.embedding import EmbeddingModel
from chonkie import RecursiveRules
from src.vectordb import ChunkType, process_documents, ingest_data, PineconeRetriever
st.set_page_config(
    page_title="📚 Table RAG Assistant",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS for better UI
st.markdown("""
    <style>
    .stApp {
        max-width: 1200px;
        margin: 0 auto;
    }
    .chat-message {
        padding: 1.5rem;
        border-radius: 0.5rem;
        margin-bottom: 1rem;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    }
    .user-message {
        background-color: #f0f2f6;
    }
    .assistant-message {
        background-color: #e8f0fe;
    }
    .st-emotion-cache-1v0mbdj.e115fcil1 {
        border-radius: 0.5rem;
    }
    </style>
""", unsafe_allow_html=True)
# Initialize session state
if "messages" not in st.session_state:
    st.session_state.messages = []
if "documents_processed" not in st.session_state:
    st.session_state.documents_processed = False
if "retriever" not in st.session_state:
    st.session_state.retriever = None
if "llm" not in st.session_state:
    st.session_state.llm = None
if "uploaded_files" not in st.session_state:
    st.session_state.uploaded_files = []
# Enhanced RAG template (passed to LLMChat.chat_with_template)
RAG_TEMPLATE = [
    {
        "role": "system",
        "content": """You are a knowledgeable assistant specialized in analyzing documents and tables.
Your responses should be:
- Accurate and based on the provided context
- Concise (three sentences maximum)
- Professional yet conversational
- Include specific references to tables when relevant
If you cannot find an answer in the context, acknowledge this clearly."""
    },
    {
        "role": "human",
        "content": "Context: {context}\n\nQuestion: {question}"
    }
]
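# The {context} and {question} placeholders above are filled at query time via
# the input_vars dict passed to chat_with_template in the chat handler below.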
def simulate_streaming_response(text: str, delay: float = 0.02) -> Iterator[str]:
    """Simulate a streaming response by yielding progressively longer prefixes of the text."""
    words = text.split()
    result = ""
    for word in words:
        result += word + " "
        time.sleep(delay)
        # Pause slightly longer after punctuation
        if any(p in word for p in ['.', '!', '?', ',']):
            time.sleep(delay * 2)
        yield result
def clear_pinecone_index(pc, index_name="vector-index"):
    """Clear the Pinecone index and reset app state."""
    if pc is None:
        st.warning("⚠️ Pinecone client is not initialized yet.")
        return
    try:
        if pc.has_index(index_name):
            pc.delete_index(index_name)
        st.session_state.documents_processed = False
        st.session_state.retriever = None
        st.session_state.messages = []
        st.session_state.llm = None
        st.session_state.uploaded_files = []
        st.success("🧹 Database cleared successfully!")
    except Exception as e:
        st.error(f"❌ Error clearing database: {str(e)}")
def format_context(results: List[Dict]) -> str:
    """Format retrieved results into a context string."""
    context_parts = []
    for result in results:
        if result.get("chunk_type") == ChunkType.TABLE.value:
            table_text = f"Table: {result['markdown_table']}"
            if result.get("table_description"):
                table_text += f"\nDescription: {result['table_description']}"
            context_parts.append(table_text)
        else:
            context_parts.append(result.get("page_content", ""))
    return "\n\n".join(context_parts)

def format_chat_message(message: Dict[str, str], results: List[Dict] = None) -> str:
    """Format a chat message, appending retrieved tables in a visually appealing way."""
    content = message["content"]
    if results:
        for result in results:
            if result.get("chunk_type") == ChunkType.TABLE.value:
                content += "\n\n---\n\n📊 **Relevant Table:**\n" + result['markdown_table']
    return content
def initialize_components(pinecone_api_key: str):
    """Initialize all required components with LangChain integration."""
    try:
        # Initialize Pinecone
        pc = Pinecone(api_key=pinecone_api_key)

        # Initialize LangChain LLM with custom parameters
        llm = LLMChat(
            model_name="mistral:7b",
            temperature=0.3  # Lower temperature for more focused responses
        )
        st.session_state.llm = llm

        # Initialize LangChain embeddings
        embedder = EmbeddingModel("nomic-embed-text")

        # Initialize chunker
        chunker = TableRecursiveChunker(
            tokenizer="gpt2",
            chunk_size=512,
            rules=RecursiveRules(),
            min_characters_per_chunk=12
        )

        # Initialize processor
        processor = TableProcessor(
            llm_model=llm,
            embedding_model=embedder,
            batch_size=8
        )

        return pc, llm, embedder, chunker, processor
    except Exception as e:
        st.error(f"❌ Error initializing components: {str(e)}")
        return None, None, None, None, None
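# Note: "mistral:7b" and "nomic-embed-text" above follow Ollama-style model
# names; the actual serving backend depends on how src.llm.LLMChat and
# src.embedding.EmbeddingModel are implemented.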
def process_all_documents(uploaded_files, chunker, processor, pc, embedder):
    """Process uploaded documents with enhanced progress tracking."""
    if not uploaded_files:
        st.warning("📤 Please upload at least one document.")
        return False

    temp_dir = None
    file_paths = []
    try:
        temp_dir = tempfile.mkdtemp()

        with st.status("📑 Processing Documents", expanded=True) as status:
            # Save uploaded files
            st.write("📁 Saving uploaded files...")
            for uploaded_file in uploaded_files:
                st.write(f"Saving {uploaded_file.name}...")
                file_path = Path(temp_dir) / uploaded_file.name
                with open(file_path, "wb") as f:
                    f.write(uploaded_file.getvalue())
                file_paths.append(str(file_path))

            # Process documents
            st.write("🔄 Processing documents...")
            processed_chunks = process_documents(
                file_paths=file_paths,
                chunker=chunker,
                processor=processor,
                output_path='./output.md'
            )

            # Ingest data
            st.write("📥 Ingesting data to vector database...")
            ingest_data(
                processed_chunks=processed_chunks,
                embedding_model=embedder,
                pinecone_client=pc
            )

            # Setup retriever
            st.write("🎯 Setting up retriever...")
            st.session_state.retriever = PineconeRetriever(
                pinecone_client=pc,
                index_name="vector-index",
                namespace="rag",
                embedding_model=embedder,
                llm_model=st.session_state.llm
            )

            st.session_state.documents_processed = True
            status.update(label="✅ Processing complete!", state="complete", expanded=False)

        return True
    except Exception as e:
        st.error(f"❌ Error processing documents: {str(e)}")
        return False
    finally:
        # Cleanup
        for file_path in file_paths:
            try:
                os.remove(file_path)
            except Exception:
                pass
        if temp_dir is not None:
            try:
                os.rmdir(temp_dir)
            except Exception:
                pass
def main():
    st.title("📚 Table RAG Assistant")
    st.markdown("---")

    pc = None

    # Sidebar configuration with improved styling
    with st.sidebar:
        st.title("⚙️ Configuration")
        pinecone_api_key = st.text_input("🔑 Enter Pinecone API Key:", type="password")
        st.markdown("---")

        col1, col2 = st.columns(2)
        with col1:
            if st.button("🧹 Clear DB", use_container_width=True):
                clear_pinecone_index(pc)
        with col2:
            if st.button("🗑️ Clear Chat", use_container_width=True):
                st.session_state.messages = []
                if st.session_state.llm is not None:
                    st.session_state.llm.clear_history()
                st.rerun()

        # Display uploaded files
        if st.session_state.uploaded_files:
            st.markdown("---")
            st.subheader("📁 Uploaded Files")
            for file in st.session_state.uploaded_files:
                st.write(f"- {file.name}")

    if not pinecone_api_key:
        st.sidebar.warning("⚠️ Please enter Pinecone API key to continue.")
        st.stop()

    # Initialize components if not already done
    if st.session_state.retriever is None:
        pc, llm, embedder, chunker, processor = initialize_components(pinecone_api_key)
        clear_pinecone_index(pc)
        if None in (pc, llm, embedder, chunker, processor):
            st.stop()

    # Document upload section with improved UI
    if not st.session_state.documents_processed:
        st.header("📄 Document Upload")
        st.markdown("Upload your documents to get started. Supported formats: PDF, DOCX, TXT, CSV, XLSX")

        uploaded_files = st.file_uploader(
            "Drop your files here",
            accept_multiple_files=True,
            type=["pdf", "docx", "txt", "csv", "xlsx"]
        )

        if uploaded_files:
            st.session_state.uploaded_files = uploaded_files
            if st.button("🚀 Process Documents", use_container_width=True):
                if process_all_documents(uploaded_files, chunker, processor, pc, embedder):
                    st.success("✨ Documents processed successfully!")

    # Enhanced chat interface with simulated streaming
    if st.session_state.documents_processed:
        st.header("💬 Chat Interface")
        st.markdown("Ask questions about your documents and tables")

        # Display chat history with improved styling
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(format_chat_message(message, message.get("results")))

        # Chat input with simulated streaming
        if prompt := st.chat_input("Ask a question..."):
            # Display user message
            with st.chat_message("user"):
                st.markdown(prompt)
            st.session_state.messages.append({"role": "user", "content": prompt})

            # Generate response with simulated streaming
            with st.chat_message("assistant"):
                response_placeholder = st.empty()

                with st.spinner("🤔 Thinking..."):
                    # Retrieve relevant content
                    results = st.session_state.retriever.invoke(
                        question=prompt,
                        top_k=3
                    )

                    # Format context and get response from LLM
                    context = format_context(results)
                    chat = st.session_state.llm
                    input_vars = {
                        "question": prompt,
                        "context": context
                    }

                    # Get the full response first
                    full_response = chat.chat_with_template(RAG_TEMPLATE, input_vars)

                # Simulate streaming of the response
                for partial_response in simulate_streaming_response(full_response):
                    response_placeholder.markdown(partial_response + "▌")

                # Display final response with tables
                formatted_response = format_chat_message(
                    {"role": "assistant", "content": full_response},
                    results
                )
                response_placeholder.markdown(formatted_response)

                # Save to chat history
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": full_response,
                    "results": results
                })

if __name__ == "__main__":
    main()