Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import os | |
| import tempfile | |
| from typing import List | |
| from unified_document_processor import UnifiedDocumentProcessor, CustomEmbeddingFunction | |
| import chromadb | |
| from chromadb.config import Settings | |
| from groq import Groq | |
def initialize_session_state():
    """Initialize all session state variables.

    Seeds three keys in ``st.session_state`` when they are not already
    present, so later code can read them unconditionally:

    * ``CHROMADB_DIR`` -- ``<current working directory>/chromadb_data``;
      the directory is created on disk when the key is first set.
    * ``processed_files`` -- an empty ``set``.
    * ``processor`` -- ``None``; this is only a placeholder, the real
      processor is built in ``StreamlitDocProcessor.__init__``.
    """
    state = st.session_state

    if 'CHROMADB_DIR' not in state:
        state.CHROMADB_DIR = os.path.join(os.getcwd(), 'chromadb_data')
        os.makedirs(state.CHROMADB_DIR, exist_ok=True)

    if 'processed_files' not in state:
        state.processed_files = set()

    if 'processor' not in state:
        # Nothing is constructed here, so there is no initialization error
        # to report from this function; StreamlitDocProcessor reports its own.
        state.processor = None
class StreamlitDocProcessor:
    """Streamlit front end for uploading documents and querying them.

    The app has two pages, chosen from a sidebar select box:
    "Upload & Process" (PDF/XML upload) and "Query" (question answering over
    selected files). All shared state lives in ``st.session_state``:
    ``processor`` (the document processor), ``processed_files`` (a set of
    file names) and ``CHROMADB_DIR`` (the ChromaDB persistence directory).
    ``initialize_session_state()`` must have run before this class is used.
    """

    def __init__(self):
        """Build the shared processor if the session does not have one yet.

        Reads the Groq API key from ``st.secrets["GROQ_API_KEY"]``. Any
        failure is shown with ``st.error`` and leaves
        ``st.session_state.processor`` as ``None``.
        """
        if st.session_state.processor is not None:
            return
        try:
            groq_api_key = st.secrets["GROQ_API_KEY"]
            # Initialize processor with persistent ChromaDB
            st.session_state.processor = self.initialize_processor(groq_api_key)
            # Update processed files after initializing processor
            st.session_state.processed_files = self.get_processed_files()
        except Exception as e:
            st.error(f"Error initializing processor: {str(e)}")

    def initialize_processor(self, groq_api_key):
        """Initialize the processor with persistent ChromaDB.

        Args:
            groq_api_key: API key passed to the ``Groq`` client.

        Returns:
            A ``UnifiedDocumentProcessor`` subclass instance whose Chroma
            client persists to ``st.session_state.CHROMADB_DIR``.
        """

        class PersistentUnifiedDocumentProcessor(UnifiedDocumentProcessor):
            """UnifiedDocumentProcessor backed by an on-disk Chroma client."""

            def __init__(self, api_key, collection_name="unified_content", persist_dir=None):
                # NOTE(review): super().__init__() is intentionally not
                # called; this sets the attributes itself so the Chroma
                # client can be a PersistentClient. Presumably these mirror
                # what the parent sets up -- confirm against
                # unified_document_processor.
                self.groq_client = Groq(api_key=api_key)
                self.max_elements_per_chunk = 50
                self.pdf_chunk_size = 500
                self.pdf_overlap = 50
                # Not defined in this class; expected to come from the parent.
                self._initialize_nltk()

                # Initialize persistent ChromaDB
                self.chroma_client = chromadb.PersistentClient(
                    path=persist_dir,
                    settings=Settings(
                        allow_reset=True,
                        is_persistent=True
                    )
                )

                # Get or create the collection in a single call rather than
                # treating *any* get_collection failure as "does not exist".
                self.collection = self.chroma_client.get_or_create_collection(
                    name=collection_name,
                    embedding_function=CustomEmbeddingFunction()
                )

        return PersistentUnifiedDocumentProcessor(
            groq_api_key,
            persist_dir=st.session_state.CHROMADB_DIR
        )

    def get_processed_files(self) -> set:
        """Get the set of processed file names from the processor.

        Returns:
            The union of the ``'pdf'`` and ``'xml'`` entries returned by the
            processor's ``get_available_files()``; an empty set when there
            is no processor or when the lookup raises (the error is shown
            with ``st.error``).
        """
        try:
            processor = st.session_state.processor
            if not processor:
                return set()
            available_files = processor.get_available_files()
            return set(available_files['pdf'] + available_files['xml'])
        except Exception as e:
            st.error(f"Error getting processed files: {str(e)}")
            return set()

    def run(self):
        """Render the title and sidebar, then dispatch to the chosen page."""
        st.title("AAS Assistant")

        # Create sidebar for navigation
        page = st.sidebar.selectbox(
            "Choose a page",
            ["Upload & Process", "Query"]
        )

        if page == "Upload & Process":
            self.upload_and_process_page()
        else:
            self.qa_page()

    def _process_uploaded_file(self, uploaded_file, progress_bar, status_text):
        """Spool one upload to a temp file, process it, and report the result.

        The temp file keeps the upload's extension and is always removed
        afterwards. Errors are reported through ``status_text`` rather than
        raised.
        """
        temp_path = None
        try:
            # Create a temporary file
            suffix = os.path.splitext(uploaded_file.name)[1]
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
                # Record the path before writing so a failed write is
                # still cleaned up in the finally clause.
                temp_path = tmp_file.name
                tmp_file.write(uploaded_file.getbuffer())

            # Process the file
            # NOTE(review): the processor only sees the temp path, not the
            # original upload name -- confirm which name it records.
            status_text.text(f'Processing {uploaded_file.name}...')
            progress_bar.progress(25)
            result = st.session_state.processor.process_file(temp_path)
            progress_bar.progress(75)

            if result['success']:
                st.session_state.processed_files.add(uploaded_file.name)
                progress_bar.progress(100)
                status_text.success(f"Successfully processed {uploaded_file.name}")
            else:
                progress_bar.progress(100)
                status_text.error(f"Failed to process {uploaded_file.name}: {result['error']}")
        except Exception as e:
            status_text.error(f"Error processing {uploaded_file.name}: {str(e)}")
        finally:
            # Clean up temporary file (best effort: a file that is already
            # gone or cannot be removed must not mask the real outcome).
            if temp_path is not None:
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass

    def upload_and_process_page(self):
        """Render the upload page and process any newly uploaded files.

        Files whose name is already in ``st.session_state.processed_files``
        are skipped with an informational message. The list of processed
        files is shown below the uploader.
        """
        st.header("Upload and Process Documents")

        # File uploader
        uploaded_files = st.file_uploader(
            "Upload PDF or XML files",
            type=['pdf', 'xml'],
            accept_multiple_files=True
        )

        for uploaded_file in uploaded_files or []:
            # Create progress bar
            progress_bar = st.progress(0)
            status_text = st.empty()

            if uploaded_file.name in st.session_state.processed_files:
                status_text.info(f"{uploaded_file.name} has already been processed")
                progress_bar.progress(100)
            else:
                self._process_uploaded_file(uploaded_file, progress_bar, status_text)

        # Display processed files
        if st.session_state.processed_files:
            st.subheader("Processed Files")
            for file in sorted(st.session_state.processed_files):
                st.text(f"✓ {file}")

    def qa_page(self):
        """Render the query page: file selection, question box, and answer.

        Refreshes ``st.session_state.processed_files`` from the processor on
        every render. Shows a warning and stops early when there are no
        processed files or none are selected. Errors are shown with
        ``st.error`` rather than raised.
        """
        st.header("Query our database")

        try:
            # Refresh available files
            st.session_state.processed_files = self.get_processed_files()

            if not st.session_state.processed_files:
                st.warning("No processed files available. Please upload and process some files first.")
                return

            # File selection
            selected_files = st.multiselect(
                "Select files to search through",
                sorted(st.session_state.processed_files),
                default=list(st.session_state.processed_files)
            )

            if not selected_files:
                st.warning("Please select at least one file to search through.")
                return

            # Question input
            question = st.text_input("Enter your question:")

            if st.button("Ask Question") and question:
                try:
                    with st.spinner("Searching for answer..."):
                        answer = st.session_state.processor.ask_question_selective(
                            question,
                            selected_files
                        )
                    st.write("Answer:", answer)
                except Exception as e:
                    st.error(f"Error getting answer: {str(e)}")
        except Exception as e:
            st.error(f"Error in Q&A interface: {str(e)}")
def main():
    """Script entry point: seed the session state, then build and run the app."""
    # Session state must exist before the app object reads it.
    initialize_session_state()
    # Construct the app and hand control to its page dispatcher.
    StreamlitDocProcessor().run()


if __name__ == "__main__":
    main()