Spaces:
Sleeping
Sleeping
| import asyncio | |
| import sys | |
| import hashlib | |
| import streamlit as st | |
| import pandas as pd | |
| import os | |
| import tempfile | |
| from typing import List, Optional, Dict, Any, Union | |
| import json | |
| import openai | |
| from datetime import datetime | |
| from langchain.output_parsers import PydanticOutputParser | |
| from langchain.prompts import ChatPromptTemplate | |
| from langchain.schema import HumanMessage, SystemMessage | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.schema.runnable import RunnablePassthrough | |
| from langchain.prompts.prompt import PromptTemplate | |
| from langchain.memory import ConversationBufferMemory | |
| from langchain_community.vectorstores import Chroma | |
| from pydantic import BaseModel, Field | |
| from Ingestion.ingest import process_document, get_processor_for_file | |
| from langchain_openai import ChatOpenAI, OpenAIEmbeddings | |
| import warnings | |
# Keep RuntimeWarning noise out of the app output.
warnings.filterwarnings(action="ignore", category=RuntimeWarning)

# Extend the import search path two directory levels up.
sys.path.append("../..")

from dotenv import load_dotenv, find_dotenv

# Load variables from the nearest .env file, then hand the key to the OpenAI
# client. A missing OPENAI_API_KEY raises KeyError here, at start-up.
_ = load_dotenv(find_dotenv())
openai.api_key = os.environ["OPENAI_API_KEY"]

# On Windows (Python 3.8+) switch asyncio to the selector event loop policy.
if sys.platform == "win32":
    if sys.version_info >= (3, 8):
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

# Page-level Streamlit configuration.
st.set_page_config(
    page_title="DocMind AI: AI-Powered Document Analysis",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Stylesheet injected once per run; the rules adapt to the browser's
# dark/light preference via prefers-color-scheme media queries.
_CUSTOM_CSS = """
<style>
/* Common styles for both modes */
.stApp {
max-width: 1200px;
margin: 0 auto;
}
/* Card styling for results */
.card {
border-radius: 5px;
padding: 1.5rem;
margin-bottom: 1rem;
border: 1px solid rgba(128, 128, 128, 0.2);
}
/* Dark mode specific */
@media (prefers-color-scheme: dark) {
.card {
background-color: rgba(255, 255, 255, 0.05);
}
.highlight-container {
background-color: rgba(255, 255, 255, 0.05);
border-left: 3px solid #4CAF50;
}
.chat-user {
background-color: rgba(0, 0, 0, 0.2);
}
.chat-ai {
background-color: rgba(76, 175, 80, 0.1);
}
}
/* Light mode specific */
@media (prefers-color-scheme: light) {
.card {
background-color: rgba(0, 0, 0, 0.02);
}
.highlight-container {
background-color: rgba(0, 0, 0, 0.03);
border-left: 3px solid #4CAF50;
}
.chat-user {
background-color: rgba(240, 240, 240, 0.7);
}
.chat-ai {
background-color: rgba(76, 175, 80, 0.05);
}
}
/* Chat message styling */
.chat-container {
margin-bottom: 1rem;
}
.chat-message {
padding: 1rem;
border-radius: 5px;
margin-bottom: 0.5rem;
}
/* Highlight sections */
.highlight-container {
padding: 1rem;
margin: 1rem 0;
border-radius: 4px;
}
/* Status indicators */
.status-success {
color: #4CAF50;
}
.status-error {
color: #F44336;
}
/* Document list */
.doc-list {
list-style-type: none;
padding-left: 0;
}
.doc-list li {
padding: 0.5rem 0;
border-bottom: 1px solid rgba(128, 128, 128, 0.2);
}
/* Document card */
.doc-card {
padding: 0.8rem;
border-radius: 4px;
border: 1px solid rgba(128, 128, 128, 0.2);
margin-bottom: 0.5rem;
cursor: pointer;
}
.doc-card:hover {
background-color: rgba(76, 175, 80, 0.1);
}
.doc-card.selected {
background-color: rgba(76, 175, 80, 0.2);
border-color: #4CAF50;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Structured result the LLM is asked to return; it is handed to
# PydanticOutputParser, which derives the format instructions from it.
# NOTE(review): documented with comments instead of a class docstring on
# purpose - pydantic copies a model docstring into the generated schema,
# which would alter the prompt text sent to the model.
class DocumentAnalysis(BaseModel):
    # Required: short overview of the document.
    summary: str = Field(
        description="A concise summary of the document",
    )
    # Required: main takeaways, one string per insight.
    key_insights: List[str] = Field(
        description="A list of key insights from the document",
    )
    # Optional: follow-up tasks; defaults to None when the model omits it.
    action_items: Optional[List[str]] = Field(
        default=None,
        description="A list of action items derived from the document",
    )
    # Optional: unresolved points; defaults to None when the model omits it.
    open_questions: Optional[List[str]] = Field(
        default=None,
        description="A list of open questions or areas needing clarification",
    )
def hash_file(file_content):
    """Return the SHA-256 hex digest of ``file_content``.

    The digest is used as the document's identity, so uploading the same
    bytes twice maps to the same key. ``file_content`` must be a bytes-like
    object (``bytes``, ``bytearray``, ``memoryview``).
    """
    digest = hashlib.sha256()
    digest.update(file_content)
    return digest.hexdigest()
class DocumentStore:
    """Disk-backed library of uploaded documents and their analysis results.

    Layout inside ``storage_dir``:

    * one file per document, named after its content hash;
    * ``metadata.json``: ``{file_hash: {"filename", "upload_date", "size"}}``;
    * ``analysis_results.json``: per-document results keyed by hash, plus the
      reserved top-level key ``"combined"`` that maps a joined-hash key to
      the result of a multi-document analysis.
    """

    def __init__(self, storage_dir="document_store"):
        """Create ``storage_dir`` if needed and load both JSON indexes."""
        self.storage_dir = storage_dir
        os.makedirs(storage_dir, exist_ok=True)
        self.metadata_path = os.path.join(storage_dir, "metadata.json")
        self.analysis_path = os.path.join(storage_dir, "analysis_results.json")
        self.load_metadata()
        self.load_analysis_results()

    @staticmethod
    def _read_json(path):
        """Return the dict stored at ``path``; ``{}`` if missing or unreadable.

        An empty, truncated or non-object JSON file is treated like a missing
        one, so a half-written index cannot prevent the app from starting.
        """
        try:
            with open(path, 'r') as f:
                data = json.load(f)
        except FileNotFoundError:
            return {}
        except json.JSONDecodeError:
            return {}
        return data if isinstance(data, dict) else {}

    @staticmethod
    def _write_json(path, data):
        """Serialize ``data`` to ``path`` atomically.

        The JSON is written to a sibling temp file and moved into place with
        ``os.replace``, so a failure during serialization leaves the previous
        index intact instead of truncating it.
        """
        tmp_path = f"{path}.tmp"
        try:
            with open(tmp_path, 'w') as f:
                json.dump(data, f)
        except Exception:
            # Do not leave the partial temp file behind.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
        os.replace(tmp_path, path)

    @staticmethod
    def _combined_key(doc_hashes):
        """Order-independent key identifying a set of documents."""
        return "_".join(sorted(doc_hashes))

    def load_metadata(self):
        """(Re)load ``self.metadata`` from disk."""
        self.metadata = self._read_json(self.metadata_path)

    def load_analysis_results(self):
        """(Re)load ``self.analysis_results`` from disk."""
        self.analysis_results = self._read_json(self.analysis_path)

    def save_metadata(self):
        """Persist ``self.metadata``."""
        self._write_json(self.metadata_path, self.metadata)

    def save_analysis_results(self):
        """Persist ``self.analysis_results``."""
        self._write_json(self.analysis_path, self.analysis_results)

    def get_all_documents(self):
        """Return all documents in the store"""
        return self.metadata

    def file_exists(self, file_hash):
        """Check if a file with the given hash exists in the store"""
        return file_hash in self.metadata

    def get_document_path(self, file_hash):
        """Get the file path for a document with the given hash.

        Returns ``None`` when the hash is not registered in the metadata.
        """
        if file_hash in self.metadata:
            return os.path.join(self.storage_dir, file_hash)
        return None

    def add_document(self, file, file_hash):
        """Add a new document to the store.

        ``file`` must expose ``getbuffer()`` and ``name`` (as Streamlit's
        uploaded-file objects and ``io.BytesIO`` do). The bytes are written
        under ``file_hash`` and a metadata record is saved.
        """
        buffer = file.getbuffer()
        file_path = os.path.join(self.storage_dir, file_hash)
        with open(file_path, 'wb') as f:
            f.write(buffer)
        self.metadata[file_hash] = {
            "filename": file.name,
            "upload_date": datetime.now().isoformat(),
            "size": len(buffer)
        }
        self.save_metadata()

    def add_analysis_result(self, doc_hash, analysis_result):
        """Store (or replace) the analysis result for a document."""
        self.analysis_results[doc_hash] = {
            "result": analysis_result,
            "timestamp": datetime.now().isoformat()
        }
        self.save_analysis_results()

    def add_combined_analysis(self, doc_hashes, analysis_result):
        """Store combined analysis result for multiple documents."""
        session_id = self._combined_key(doc_hashes)
        combined = self.analysis_results.setdefault("combined", {})
        combined[session_id] = {
            "result": analysis_result,
            "timestamp": datetime.now().isoformat(),
            # Copy: the caller's list (e.g. the live selection) may be
            # mutated later and must not change the stored record.
            "doc_hashes": list(doc_hashes)
        }
        self.save_analysis_results()

    def has_analysis(self, doc_hash):
        """Return True if a per-document analysis is stored for ``doc_hash``."""
        return doc_hash in self.analysis_results

    def has_combined_analysis(self, doc_hashes):
        """Return True if a combined analysis exists for this document set."""
        combined = self.analysis_results.get("combined")
        if combined is None:
            return False
        return self._combined_key(doc_hashes) in combined

    def get_analysis(self, doc_hash):
        """Return the stored result for ``doc_hash``, or ``None``."""
        return self.analysis_results.get(doc_hash, {}).get("result")

    def get_combined_analysis(self, doc_hashes):
        """Return the stored combined result for this document set, or ``None``."""
        combined = self.analysis_results.get("combined")
        if combined is None:
            return None
        return combined.get(self._combined_key(doc_hashes), {}).get("result")
# Normalize raw LLM output into a JSON string the output parser can consume.
def clean_llm_response(response):
    """Extract the JSON payload from an LLM response.

    Args:
        response: One of
            * an OpenAI-style dict with ``choices[0].message.content``;
            * a message object exposing a string ``content`` attribute
              (what ``llm.invoke`` returns);
            * anything else, which is converted with ``str()``.

    Returns:
        The response text with a surrounding markdown code fence (and an
        optional ``json`` language tag) removed and whitespace stripped. If
        that text is a JSON object with a top-level ``"properties"`` key, the
        value of that key is returned re-serialized as JSON instead. Text
        that is not valid JSON is returned as-is (after fence removal).
    """
    if isinstance(response, dict) and 'choices' in response:
        content = response['choices'][0]['message']['content']
    else:
        # Prefer the message text: str() on a message object yields a repr
        # with escaped newlines, which can never parse as JSON.
        content = getattr(response, 'content', None)
        if not isinstance(content, str):
            content = str(response)

    if '```' in content:
        parts = content.split('```')
        if len(parts) >= 3:  # opening and closing fence both present
            content = parts[1]
            # Drop the language tag whatever its casing (json, JSON, Json).
            if content[:4].lower() == 'json':
                content = content[4:].lstrip()
    elif '`json' in content:
        # Single-backtick variant: `json {...}`
        parts = content.split('`json')
        if len(parts) >= 2:
            content = parts[1]
            if '`' in content:
                content = content.split('`')[0]

    content = content.strip()

    try:
        json_data = json.loads(content)
    except (ValueError, RecursionError):
        # Not JSON (or pathologically nested): let the caller deal with it.
        return content

    # Some models wrap the answer in the schema's "properties" object.
    if isinstance(json_data, dict) and "properties" in json_data:
        return json.dumps(json_data["properties"])
    return content
# Build the chat model. Kept free of Streamlit widgets.
# NOTE: no cache decorator is applied, so every call constructs a new client.
def load_model():
    """Create the OpenAI chat model used for analysis and chat.

    Returns:
        A ``ChatOpenAI`` instance (``gpt-4o-mini``, temperature 0.1), or
        ``None`` if construction fails; the failure is reported to the user
        through ``st.error``.
    """
    try:
        return ChatOpenAI(temperature=0.1, model_name="gpt-4o-mini")
    except Exception as e:
        # The model loaded here is OpenAI's; the message must say so.
        st.error(f"Error loading OpenAI model: {e}")
        return None
# Build the embeddings client. Kept free of Streamlit widgets.
def load_embeddings():
    """Create the OpenAI embeddings client (``text-embedding-3-large``).

    Returns the client, or ``None`` when construction fails; in that case
    the error is surfaced through ``st.error``.
    """
    client = None
    try:
        client = OpenAIEmbeddings(model="text-embedding-3-large")
    except Exception as e:
        st.error(f"Error loading embeddings model: {e}")
    return client
# Seed session state on the first run of a session. Each entry maps a key to
# a zero-argument factory so mutable defaults are created fresh and the
# DocumentStore (which touches the disk) is only built when actually missing.
_SESSION_DEFAULTS = {
    'model_loaded': lambda: False,
    'embeddings_loaded': lambda: False,
    'document_store': lambda: DocumentStore(),
    'chat_sessions': dict,
    'session_history': dict,
    'selected_docs': list,
    'analyzed_docs': set,
    'analyzed_combinations': set,
    'active_tab': lambda: "Upload & Manage Documents",
}
for _state_key, _make_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# Sidebar branding.
for _sidebar_html in (
    "<div style='text-align: center;'><h1>🧠 DocMind AI</h1></div>",
    "<div style='text-align: center;'>AI-Powered Document Analysis</div>",
):
    st.sidebar.markdown(_sidebar_html, unsafe_allow_html=True)
st.sidebar.markdown("---")

# Chat model: built on every run; the success flag is only (re)evaluated
# while it is still unset/false. The script stops here if the flag is false.
with st.sidebar:
    _model_flag_set = st.session_state.get('model_loaded', False)
    llm = load_model()
    if not _model_flag_set:
        st.session_state['model_loaded'] = True if llm else False
    if not st.session_state.get('model_loaded'):
        st.markdown("<div class='status-error'>❌ Error loading model.</div>", unsafe_allow_html=True)
        st.stop()
    else:
        st.markdown("<div class='status-success'>✅ Model loaded successfully!</div>", unsafe_allow_html=True)

# Embeddings: same scheme, with a spinner shown only until the first success.
with st.sidebar:
    if st.session_state['embeddings_loaded']:
        embeddings = load_embeddings()
    else:
        with st.spinner("Loading embeddings..."):
            embeddings = load_embeddings()
            st.session_state['embeddings_loaded'] = True if embeddings else False
    if not st.session_state.get('embeddings_loaded'):
        st.markdown("<div class='status-error'>❌ Error loading embeddings.</div>", unsafe_allow_html=True)
        st.stop()
    else:
        st.markdown("<div class='status-success'>✅ Embeddings loaded successfully!</div>", unsafe_allow_html=True)
# Derive the key that identifies a chat/analysis session for a document set.
def get_session_id(doc_hashes):
    """Return an order-independent id: the sorted hashes joined by ``_``."""
    ordered = list(doc_hashes)
    ordered.sort()
    return "_".join(ordered)
# Extract text from the selected documents held in the document store.
def process_documents(file_hashes):
    """Extract text from every stored document in ``file_hashes``.

    Extraction runs on a small thread pool. The worker function performs no
    Streamlit calls; progress updates and error messages are issued from the
    calling (script) thread as each future completes, because Streamlit
    commands are only supported on the script thread.

    Args:
        file_hashes: Iterable of content hashes registered in the document
            store kept in ``st.session_state['document_store']``.

    Returns:
        A list of ``{"name", "data", "hash"}`` dicts, in the order of
        ``file_hashes``, for documents that yielded non-blank text. Documents
        that are missing, have no processor, produce blank text or fail are
        omitted; failures are reported with ``st.error``.
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    doc_store = st.session_state['document_store']
    progress_bar = st.progress(0)

    file_hashes = list(file_hashes)
    total_docs = len(file_hashes)
    if total_docs == 0:
        # ThreadPoolExecutor rejects max_workers=0; nothing to do anyway.
        return []

    def display_name(file_hash):
        """Best available label for messages; falls back to the hash."""
        return doc_store.metadata.get(file_hash, {}).get("filename", file_hash)

    def process_single_document(file_hash):
        """Worker: return a result record, or None when there is no text."""
        file_path = doc_store.get_document_path(file_hash)
        if not file_path or not os.path.exists(file_path):
            return None
        processor = get_processor_for_file(file_path)
        if not processor:
            return None
        doc_data = process_document_in_chunks(file_path, processor)
        if doc_data is None or not doc_data.strip():
            return None
        return {"name": display_name(file_hash), "data": doc_data, "hash": file_hash}

    records_by_index = {}
    with ThreadPoolExecutor(max_workers=min(4, total_docs)) as executor:
        futures = {
            executor.submit(process_single_document, fh): (i, fh)
            for i, fh in enumerate(file_hashes)
        }
        for completed, future in enumerate(as_completed(futures), start=1):
            index, file_hash = futures[future]
            try:
                record = future.result()
            except Exception as e:
                st.error(f"Error processing {display_name(file_hash)}: {str(e)}")
            else:
                if record is not None:
                    records_by_index[index] = record
            # Count finished tasks so the bar only ever moves forward.
            progress_bar.progress(completed / total_docs)

    return [records_by_index[i] for i in sorted(records_by_index)]
def process_document_in_chunks(file_path, processor, chunk_size=5*1024*1024, timeout=30):
    """Run ``processor`` on ``file_path`` with safeguards for large files.

    * Files of at most ``chunk_size`` bytes are passed straight to
      ``processor``.
    * Larger ``.pdf`` files are delegated to ``process_pdf_by_page``.
    * Other large files are processed under a ``SIGALRM`` time limit of
      ``timeout`` seconds; on timeout ``basic_text_extraction`` is used
      instead. The time limit needs ``signal.SIGALRM`` and the main thread;
      where either is unavailable (Windows, worker threads) the processor
      runs without a limit.

    Args:
        file_path: Path of the document on disk.
        processor: Callable taking the path and returning extracted text.
        chunk_size: Size threshold in bytes separating small from large files.
        timeout: Time limit in whole seconds for large non-PDF files.

    Returns:
        Whatever the chosen extraction routine returns.

    Raises:
        Any exception raised by ``processor`` (it is invoked exactly once).
    """
    if os.path.getsize(file_path) <= chunk_size:
        return processor(file_path)

    if os.path.splitext(file_path)[1].lower() == ".pdf":
        return process_pdf_by_page(file_path)

    import signal
    import threading

    # Handlers can only be installed from the main thread, and SIGALRM does
    # not exist on Windows: check explicitly rather than swallow errors.
    alarm_usable = (
        hasattr(signal, "SIGALRM")
        and threading.current_thread() is threading.main_thread()
    )
    if not alarm_usable:
        return processor(file_path)

    class TimeoutException(Exception):
        """Raised by the alarm handler when the time limit is exceeded."""

    def timeout_handler(signum, frame):
        raise TimeoutException("Processing timed out")

    previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(int(timeout))
    try:
        return processor(file_path)
    except TimeoutException:
        # Fall back to basic text extraction when the processor is too slow.
        return basic_text_extraction(file_path)
    finally:
        # Always cancel the pending alarm and put the old handler back, so a
        # failing processor cannot leave a stray alarm armed.
        signal.alarm(0)
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
# Build the retrieval-backed chat function for a set of processed documents.
def setup_document_chat(processed_docs):
    """Index ``processed_docs`` in Chroma and register a QA function.

    Each document's text is split into chunks, every chunk is prefixed with
    ``Source: <name>`` so answers can cite their origin, and the chunks are
    embedded into a Chroma collection. A closure answering questions over
    that collection is stored in ``st.session_state['chat_sessions']`` under
    the session id, and an empty history list is created if none exists.

    Args:
        processed_docs: List of ``{"name", "data", "hash"}`` dicts as
            produced by ``process_documents``.

    Returns:
        The session id on success; ``None`` if no chunks could be produced
        or an error occurred (both are reported in the UI).
    """
    doc_hashes = [doc['hash'] for doc in processed_docs]
    session_id = get_session_id(doc_hashes)
    with st.spinner("Setting up document chat..."):
        try:
            # Larger chunks keep the number of embedding calls down.
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1500,
                chunk_overlap=150,
                length_function=len
            )

            all_chunks = []
            for doc in processed_docs:
                if not doc['data'] or not doc['data'].strip():
                    continue
                # The source name is embedded in the chunk text itself.
                all_chunks.extend(
                    f"Source: {doc['name']}\n\n{chunk}"
                    for chunk in text_splitter.split_text(doc['data'])
                )

            if not all_chunks:
                st.warning("No text chunks were created from the documents. Chat functionality is unavailable.")
                return None

            # The session id is one 64-char hash per document, which exceeds
            # Chroma's collection-name length limit. Derive a short, stable
            # name from it instead.
            collection_name = "docmind_" + hashlib.sha256(
                session_id.encode("utf-8")
            ).hexdigest()[:32]

            vectorstore = Chroma.from_texts(
                texts=all_chunks,
                embedding=embeddings,
                collection_name=collection_name,
                collection_metadata={"timestamp": datetime.now().isoformat()}
            )
            retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

            system_template = (
                "You are DocMind AI, a helpful assistant that answers questions about documents.\n"
                "Use the following pieces of retrieved context to answer the user's question.\n"
                "If the answer isn't in the context, just say you don't know.\n"
                "Include the source document name when providing information.\n"
                "Context:\n"
                "{context}\n"
            )
            # Built once here rather than on every question.
            qa_prompt = ChatPromptTemplate.from_messages([
                ("system", system_template),
                ("human", "{question}")
            ])

            def document_qa(query):
                """Answer ``query`` from the top retrieved chunks."""
                docs = retriever.invoke(query)
                context = "\n\n".join(doc.page_content for doc in docs)
                messages = qa_prompt.format_messages(
                    context=context,
                    question=query
                )
                response = llm.invoke(messages)
                return {"answer": response}

            st.session_state['chat_sessions'][session_id] = document_qa
            if session_id not in st.session_state['session_history']:
                st.session_state['session_history'][session_id] = []
            return session_id
        except Exception as e:
            st.error(f"Error setting up document chat: {str(e)}")
            return None
# Main content: three tabs. 'active_tab' in session state names the tab the
# app last asked to show; its position and container are resolved here.
tab_options = [
    "Upload & Manage Documents",
    "Document Analysis",
    "Chat with Documents",
]
tab_index = tab_options.index(st.session_state['active_tab'])
tabs = list(st.tabs(tab_options))
tab1, tab2, tab3 = tabs
active_tab = tabs[tab_index]
# Tab 1: Document Upload and Management
with tab1:
    st.header("Upload & Manage Documents")

    # File upload with content-hash deduplication.
    uploaded_files = st.file_uploader(
        "Upload Documents",
        accept_multiple_files=True,
        type=["pdf", "docx", "txt", "xlsx", "md", "json", "xml", "rtf", "csv", "msg", "pptx", "odt", "epub",
              "py", "js", "java", "ts", "tsx", "c", "cpp", "h", "html", "css", "sql", "rb", "go", "rs", "php"]
    )
    doc_store = st.session_state['document_store']

    # The uploader keeps its files across reruns. Remember which uploads were
    # already auto-selected so a later "Deselect" click is not undone by the
    # next rerun re-selecting the same file.
    if 'auto_selected_uploads' not in st.session_state:
        st.session_state['auto_selected_uploads'] = set()
    auto_selected = st.session_state['auto_selected_uploads']

    new_files = []
    existing_files = []
    if uploaded_files:
        for file in uploaded_files:
            file_hash = hash_file(file.getbuffer())
            if doc_store.file_exists(file_hash):
                existing_files.append((file.name, file_hash))
            else:
                doc_store.add_document(file, file_hash)
                new_files.append((file.name, file_hash))

    # Forget uploads that left the uploader; re-uploading selects them again.
    auto_selected.intersection_update(
        {file_hash for _, file_hash in new_files + existing_files}
    )

    def _auto_select_once(file_hash):
        """Select an uploaded document the first time it is seen."""
        if file_hash in auto_selected:
            return
        auto_selected.add(file_hash)
        if file_hash not in st.session_state['selected_docs']:
            st.session_state['selected_docs'].append(file_hash)

    if uploaded_files:
        # Upload status: new files on the left, duplicates on the right.
        col1, col2 = st.columns(2)
        with col1:
            if new_files:
                st.markdown("<div class='highlight-container'>", unsafe_allow_html=True)
                st.markdown("### New Documents Added")
                for name, file_hash in new_files:
                    st.markdown(f"- ✅ {name}")
                    _auto_select_once(file_hash)
                st.markdown("</div>", unsafe_allow_html=True)
        with col2:
            if existing_files:
                st.markdown("<div class='highlight-container'>", unsafe_allow_html=True)
                st.markdown("### Already Existing Documents")
                for name, file_hash in existing_files:
                    st.markdown(f"- ℹ️ {name} (already in library)")
                    _auto_select_once(file_hash)
                st.markdown("</div>", unsafe_allow_html=True)

    # Document library
    st.markdown("---")
    st.header("Document Library")
    available_docs = doc_store.get_all_documents()
    if available_docs:
        st.markdown("Select documents for analysis or chat:")
        # Three-column grid of document cards.
        cols = st.columns(3)
        for i, (doc_hash, doc_info) in enumerate(available_docs.items()):
            with cols[i % 3]:
                is_selected = doc_hash in st.session_state['selected_docs']
                is_analyzed = doc_hash in st.session_state['analyzed_docs']
                card_class = "doc-card selected" if is_selected else "doc-card"
                with st.container():
                    st.markdown(f"<div class='{card_class}'>", unsafe_allow_html=True)
                    analyzed_badge = "✅ " if is_analyzed else ""
                    st.markdown(f"**{analyzed_badge}{doc_info['filename']}**")
                    st.markdown(f"Uploaded: {doc_info['upload_date'][:10]}")
                    st.markdown(f"Size: {doc_info['size'] // 1024} KB")
                    if is_analyzed:
                        st.markdown("<span style='color:#4CAF50;font-size:0.8em;'>Analysis available</span>", unsafe_allow_html=True)
                    if st.button("Select" if not is_selected else "Deselect", key=f"btn_{doc_hash}"):
                        if is_selected:
                            st.session_state['selected_docs'].remove(doc_hash)
                        else:
                            st.session_state['selected_docs'].append(doc_hash)
                        # Rerun so the card and button reflect the new state.
                        st.rerun()
                    st.markdown("</div>", unsafe_allow_html=True)

        # Selection summary
        st.markdown("---")
        if st.session_state['selected_docs']:
            analyzed_count = sum(1 for doc_hash in st.session_state['selected_docs'] if doc_hash in st.session_state['analyzed_docs'])
            total_selected = len(st.session_state['selected_docs'])
            if analyzed_count > 0:
                st.success(f"{total_selected} documents selected for analysis ({analyzed_count} already analyzed)")
                # Offer a shortcut to chat once every selected document is analyzed.
                if analyzed_count == total_selected:
                    if st.button("Chat with selected documents"):
                        st.session_state['active_tab'] = "Chat with Documents"
                        st.rerun()
            else:
                st.success(f"{total_selected} documents selected for analysis")
        else:
            st.info("No documents selected. Please select documents for analysis.")
    else:
        st.info("No documents in the library. Please upload documents.")
| # Tab 2: Document Analysis | |
| with tab2: | |
| st.header("Document Analysis") | |
| # Mode Selection | |
| st.subheader("Analysis Configuration") | |
| analysis_mode = st.radio( | |
| "Analysis Mode", | |
| ["Analyze each document separately", "Combine analysis for all documents"] | |
| ) | |
| # Prompt Selection | |
| prompt_options = { | |
| "Comprehensive Document Analysis": "Analyze the provided document comprehensively. Generate a summary, extract key insights, identify action items, and list open questions.", | |
| "Extract Key Insights and Action Items": "Extract key insights and action items from the provided document.", | |
| "Summarize and Identify Open Questions": "Summarize the provided document and identify any open questions that need clarification.", | |
| "Custom Prompt": "Enter a custom prompt below:" | |
| } | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| selected_prompt_option = st.selectbox("Select Prompt", list(prompt_options.keys())) | |
| custom_prompt = "" | |
| if selected_prompt_option == "Custom Prompt": | |
| custom_prompt = st.text_area("Enter Custom Prompt", height=100) | |
| # Tone Selection | |
| tone_options = [ | |
| "Professional", "Academic", "Informal", "Creative", "Neutral", | |
| "Direct", "Empathetic", "Humorous", "Authoritative", "Inquisitive" | |
| ] | |
| with col2: | |
| selected_tone = st.selectbox("Select Tone", tone_options) | |
| selected_length = st.selectbox( | |
| "Select Response Format", | |
| ["Concise", "Detailed", "Comprehensive", "Bullet Points"] | |
| ) | |
| # Instructions Selection | |
| instruction_options = { | |
| "General Assistant": "Act as a helpful assistant.", | |
| "Researcher": "Act as a researcher providing in-depth analysis.", | |
| "Software Engineer": "Act as a software engineer focusing on code and technical details.", | |
| "Product Manager": "Act as a product manager considering strategy and user experience.", | |
| "Data Scientist": "Act as a data scientist emphasizing data analysis.", | |
| "Business Analyst": "Act as a business analyst considering strategic aspects.", | |
| "Technical Writer": "Act as a technical writer creating clear documentation.", | |
| "Marketing Specialist": "Act as a marketing specialist focusing on branding.", | |
| "HR Manager": "Act as an HR manager considering people aspects.", | |
| "Legal Advisor": "Act as a legal advisor providing legal perspective.", | |
| "Custom Instructions": "Enter custom instructions below:" | |
| } | |
| selected_instruction = st.selectbox("Select Assistant Behavior", list(instruction_options.keys())) | |
| custom_instruction = "" | |
| if selected_instruction == "Custom Instructions": | |
| custom_instruction = st.text_area("Enter Custom Instructions", height=100) | |
| # Display selected documents for analysis | |
| st.subheader("Selected Documents for Analysis") | |
| selected_docs = st.session_state['selected_docs'] | |
| if selected_docs: | |
| st.markdown("<ul class='doc-list'>", unsafe_allow_html=True) | |
| for doc_hash in selected_docs: | |
| if doc_hash in doc_store.metadata: | |
| st.markdown(f"<li>📄 {doc_store.metadata[doc_hash]['filename']}</li>", unsafe_allow_html=True) | |
| st.markdown("</ul>", unsafe_allow_html=True) | |
| # Create a centered button | |
| col1, col2, col3 = st.columns([1, 2, 1]) | |
| with col2: | |
| analyze_button = st.button("Extract and Analyze Documents", use_container_width=True) | |
| # Analysis Results area placeholder | |
| analysis_results = st.container() | |
| if analyze_button: | |
| # Process the documents and run analysis | |
| with analysis_results: | |
| with st.spinner("Analyzing documents..."): | |
| processed_docs = process_documents(selected_docs) | |
| if not processed_docs: | |
| st.error("No documents could be processed. Please check the file formats and try again.") | |
| else: | |
| # Build the prompt | |
| if selected_prompt_option == "Custom Prompt": | |
| prompt_text = custom_prompt | |
| else: | |
| prompt_text = prompt_options[selected_prompt_option] | |
| if selected_instruction == "Custom Instructions": | |
| instruction_text = custom_instruction | |
| else: | |
| instruction_text = instruction_options[selected_instruction] | |
| # Add tone guidance | |
| tone_guidance = f"Use a {selected_tone.lower()} tone in your response." | |
| # Add length guidance | |
| length_guidance = "" | |
| if selected_length == "Concise": | |
| length_guidance = "Keep your response brief and to the point." | |
| elif selected_length == "Detailed": | |
| length_guidance = "Provide a detailed response with thorough explanations." | |
| elif selected_length == "Comprehensive": | |
| length_guidance = "Provide a comprehensive in-depth analysis covering all aspects." | |
| elif selected_length == "Bullet Points": | |
| length_guidance = "Format your response primarily using bullet points for clarity." | |
| # Set up the output parser | |
| output_parser = PydanticOutputParser(pydantic_object=DocumentAnalysis) | |
| format_instructions = output_parser.get_format_instructions() | |
| if analysis_mode == "Analyze each document separately": | |
| results = [] | |
| for doc in processed_docs: | |
| with st.spinner(f"Analyzing {doc['name']}..."): | |
| # Create system message with combined instructions | |
| system_message = f"{instruction_text} {tone_guidance} {length_guidance} Format your response according to these instructions: {format_instructions}" | |
| prompt = f""" | |
| {prompt_text} | |
| Document: {doc['name']} | |
| Content: {doc['data']} | |
| """ | |
| try: | |
| # Create a prompt template | |
| system_template = f"{instruction_text} {tone_guidance} {length_guidance}" | |
| messages = [ | |
| SystemMessage(content=system_template), | |
| SystemMessage(content=f"Format your response according to these instructions: {format_instructions}"), | |
| HumanMessage(content="{input}") | |
| ] | |
| template = ChatPromptTemplate.from_messages(messages) | |
| messages = template.format_messages(input=prompt) | |
| response = llm.invoke(messages) | |
| # Try to parse the response into the pydantic model | |
| try: | |
| # Clean the response before parsing | |
| cleaned_response = clean_llm_response(response) | |
| parsed_response = output_parser.parse(cleaned_response) | |
| results.append({ | |
| "document_name": doc['name'], | |
| "analysis": parsed_response.dict() | |
| }) | |
| except Exception as e: | |
| # If parsing fails, include the raw response | |
| results.append({ | |
| "document_name": doc['name'], | |
| "analysis": str(response), | |
| "parsing_error": str(e) | |
| }) | |
| except Exception as e: | |
| st.error(f"Error analyzing {doc['name']}: {str(e)}") | |
| # Display results with card-based UI | |
| for result in results: | |
| st.markdown(f"<div class='card'>", unsafe_allow_html=True) | |
| st.markdown(f"<h3>Analysis for: {result['document_name']}</h3>", unsafe_allow_html=True) | |
| if isinstance(result['analysis'], dict) and 'parsing_error' not in result: | |
| # Structured output | |
| st.markdown("<div class='highlight-container'>", unsafe_allow_html=True) | |
| st.markdown("### Summary") | |
| st.write(result['analysis']['summary']) | |
| st.markdown("</div>", unsafe_allow_html=True) | |
| st.markdown("### Key Insights") | |
| for insight in result['analysis']['key_insights']: | |
| st.markdown(f"- {insight}") | |
| if result['analysis'].get('action_items'): | |
| st.markdown("<div class='highlight-container'>", unsafe_allow_html=True) | |
| st.markdown("### Action Items") | |
| for item in result['analysis']['action_items']: | |
| st.markdown(f"- {item}") | |
| st.markdown("</div>", unsafe_allow_html=True) | |
| if result['analysis'].get('open_questions'): | |
| st.markdown("### Open Questions") | |
| for question in result['analysis']['open_questions']: | |
| st.markdown(f"- {question}") | |
| else: | |
| # Raw output | |
| st.markdown(result['analysis']) | |
| if 'parsing_error' in result: | |
| st.info(f"Note: The response could not be parsed into the expected format. Error: {result['parsing_error']}") | |
| if 'parsing_error' not in result: | |
| doc_hash = next((doc['hash'] for doc in processed_docs if doc['name'] == result['document_name']), None) | |
| if doc_hash: | |
| doc_store.add_analysis_result(doc_hash, result['analysis']) | |
| st.session_state['analyzed_docs'].add(doc_hash) | |
| if results: | |
| st.markdown("---") | |
| if st.button("Chat with these documents"): | |
| # Switch to the chat tab | |
| st.session_state['active_tab'] = "Chat with Documents" | |
| st.rerun() | |
| st.markdown("</div>", unsafe_allow_html=True) | |
| else: # Combined analysis for all documents | |
| with st.spinner("Analyzing all documents together..."): | |
| # Combine all documents | |
| combined_docs = [] | |
| for doc in processed_docs: | |
| doc_content = f"Document: {doc['name']}\n\nContent: {doc['data']}" | |
| combined_docs.append(doc_content) | |
| combined_content = "\n\n" + "\n\n---\n\n".join(combined_docs) | |
| # Create system message with combined instructions | |
| system_message = f"{instruction_text} {tone_guidance} {length_guidance} Format your response according to these instructions: {format_instructions}" | |
| # Create the prompt template | |
| template = ChatPromptTemplate.from_messages([ | |
| ("system", system_message), | |
| ("human", "{input}") | |
| ]) | |
| # Create the prompt | |
| prompt = f""" | |
| {prompt_text} | |
| {combined_content} | |
| """ | |
| try: | |
| chain = template | llm | |
| response = chain.invoke({"input": prompt}) | |
| # Try to parse the response into the pydantic model | |
| try: | |
| cleaned_response = clean_llm_response(response) | |
| parsed_response = output_parser.parse(cleaned_response) | |
| st.markdown("<div class='card'>", unsafe_allow_html=True) | |
| st.markdown("<h3>Combined Analysis for All Documents</h3>", unsafe_allow_html=True) | |
| st.markdown("<div class='highlight-container'>", unsafe_allow_html=True) | |
| st.markdown("### Summary") | |
| st.write(parsed_response.summary) | |
| st.markdown("</div>", unsafe_allow_html=True) | |
| st.markdown("### Key Insights") | |
| for insight in parsed_response.key_insights: | |
| st.markdown(f"- {insight}") | |
| if parsed_response.action_items: | |
| st.markdown("<div class='highlight-container'>", unsafe_allow_html=True) | |
| st.markdown("### Action Items") | |
| for item in parsed_response.action_items: | |
| st.markdown(f"- {item}") | |
| st.markdown("</div>", unsafe_allow_html=True) | |
| if parsed_response.open_questions: | |
| st.markdown("### Open Questions") | |
| for question in parsed_response.open_questions: | |
| st.markdown(f"- {question}") | |
| if parsed_response: | |
| # Store the combined analysis | |
| doc_store.add_combined_analysis([doc['hash'] for doc in processed_docs], parsed_response.dict()) | |
| session_id = get_session_id([doc['hash'] for doc in processed_docs]) | |
| st.session_state['analyzed_combinations'].add(session_id) | |
| # Add button to chat with these documents | |
| st.markdown("---") | |
| if st.button("Chat with these documents"): | |
| # Switch to the chat tab | |
| st.session_state['active_tab'] = "Chat with Documents" | |
| st.rerun() | |
| st.markdown("</div>", unsafe_allow_html=True) | |
| except Exception as e: | |
| # If parsing fails, display raw response | |
| st.markdown("<div class='card'>", unsafe_allow_html=True) | |
| st.markdown("<h3>Combined Analysis for All Documents</h3>", unsafe_allow_html=True) | |
| st.markdown(str(response)) | |
| st.info(f"Note: The response could not be parsed into the expected format. Error: {str(e)}") | |
| st.markdown("</div>", unsafe_allow_html=True) | |
| except Exception as e: | |
| st.error(f"Error analyzing documents: {str(e)}") | |
# Tab 3: Chat with Documents
#
# Renders the chat tab for the documents currently held in
# st.session_state['selected_docs']:
#   1. lists the selected documents and whether each one was analyzed,
#   2. shows the stored combined analysis (if any) inside an expander,
#   3. builds the QA session for this document set when none exists yet,
#   4. shows the chat history and answers newly submitted questions.
with tab3:
    # Local import: file names, user questions and model answers are
    # interpolated into raw HTML below (unsafe_allow_html=True), so they must
    # be escaped first to keep stray markup from being injected into the page.
    import html

    st.header("Chat with Documents")

    # Display selected documents for chat
    st.subheader("Selected Documents")
    selected = st.session_state['selected_docs']

    if selected:
        # Display selected documents
        st.markdown("<ul class='doc-list'>", unsafe_allow_html=True)
        for doc_hash in selected:
            if doc_hash in doc_store.metadata:
                doc_name = html.escape(str(doc_store.metadata[doc_hash]["filename"]))
                analyzed_status = "✅ (Analyzed)" if doc_hash in st.session_state['analyzed_docs'] else "📄"
                st.markdown(f"<li>{analyzed_status} {doc_name}</li>", unsafe_allow_html=True)
        st.markdown("</ul>", unsafe_allow_html=True)

        # Check if all documents have been analyzed
        all_analyzed = all(doc_hash in st.session_state['analyzed_docs'] for doc_hash in selected)

        # One id per document selection; reused for both the analysis lookup
        # and the chat-session lookup below.
        session_id = get_session_id(selected)
        has_combined_analysis = session_id in st.session_state['analyzed_combinations']

        # Show analysis results if available
        if has_combined_analysis:
            with st.expander("View Combined Analysis Results", expanded=False):
                combined_analysis = doc_store.get_combined_analysis(selected)
                if combined_analysis:
                    # Display the combined analysis
                    st.markdown("<div class='card'>", unsafe_allow_html=True)
                    st.markdown("<h3>Combined Analysis for All Documents</h3>", unsafe_allow_html=True)

                    st.markdown("<div class='highlight-container'>", unsafe_allow_html=True)
                    st.markdown("### Summary")
                    st.write(combined_analysis['summary'])
                    st.markdown("</div>", unsafe_allow_html=True)

                    st.markdown("### Key Insights")
                    for insight in combined_analysis['key_insights']:
                        st.markdown(f"- {insight}")

                    if combined_analysis.get('action_items'):
                        st.markdown("<div class='highlight-container'>", unsafe_allow_html=True)
                        st.markdown("### Action Items")
                        for item in combined_analysis['action_items']:
                            st.markdown(f"- {item}")
                        st.markdown("</div>", unsafe_allow_html=True)

                    if combined_analysis.get('open_questions'):
                        st.markdown("### Open Questions")
                        for question in combined_analysis['open_questions']:
                            st.markdown(f"- {question}")

                    st.markdown("</div>", unsafe_allow_html=True)

        # Set up chat for these documents unless a session already exists
        if session_id not in st.session_state.get('chat_sessions', {}):
            # If documents have been analyzed, show a message
            if all_analyzed or has_combined_analysis:
                st.info("Documents have been analyzed. Setting up chat functionality...")

            # Process documents and set up chat
            processed_docs = process_documents(selected)
            if not processed_docs:
                st.error("Could not process the selected documents.")
                st.stop()

            new_session_id = setup_document_chat(processed_docs)
            if not new_session_id:
                st.error("Failed to set up chat for these documents.")
                st.stop()

            session_id = new_session_id
            st.success("Chat is ready! Ask questions about your documents below.")

        # Chat interface
        st.markdown("<div class='card'>", unsafe_allow_html=True)

        # The question is collected through a form that clears on submit.
        # A bare text_input keeps its value across reruns, so combined with
        # the st.rerun() below it would re-send the same question on every
        # rerun; the submit flag is only True for the run that the submission
        # itself triggered.
        with st.form("chat_question_form", clear_on_submit=True):
            user_question = st.text_input("Ask a question about your documents:")
            submitted = st.form_submit_button("Send")

        # Use session history
        if session_id in st.session_state['session_history']:
            # Display chat history (escaped, because it is rendered as raw HTML)
            for exchange in st.session_state['session_history'][session_id]:
                safe_question = html.escape(str(exchange['question']))
                safe_answer = html.escape(str(exchange['answer']))
                st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
                st.markdown(f"<div class='chat-message chat-user'><strong>You:</strong> {safe_question}</div>", unsafe_allow_html=True)
                st.markdown(f"<div class='chat-message chat-ai'><strong>DocMind AI:</strong> {safe_answer}</div>", unsafe_allow_html=True)
                st.markdown("</div>", unsafe_allow_html=True)

        # Ignore empty / whitespace-only submissions
        if submitted and user_question.strip():
            answered = False
            with st.spinner("Generating response..."):
                try:
                    # Get the QA function for this session
                    qa_function = st.session_state['chat_sessions'][session_id]
                    response = qa_function(user_question)

                    # Add to session history
                    if session_id not in st.session_state['session_history']:
                        st.session_state['session_history'][session_id] = []
                    st.session_state['session_history'][session_id].append({
                        "question": user_question,
                        "answer": response['answer']
                    })
                    answered = True
                except Exception as e:
                    st.error(f"Error generating response: {str(e)}")

            # Rerun outside the try block so the error handler only covers
            # the QA call and the history update.
            if answered:
                # Force refresh to show new message
                st.rerun()

        st.markdown("</div>", unsafe_allow_html=True)

        # Option to clear chat history (only offered when there is something to clear)
        if session_id in st.session_state['session_history'] and st.session_state['session_history'][session_id]:
            if st.button("Clear Chat History"):
                st.session_state['session_history'][session_id] = []
                st.success("Chat history cleared!")
                st.rerun()
    else:
        st.info("Please select documents from the 'Upload & Manage Documents' tab first.")
# Footer: a horizontal rule followed by a centered attribution block.
footer_html = """
<div style="text-align: center">
<p>Built with ❤️ using Streamlit</p>
<p>DocMind AI - AI-Powered Document Analysis</p>
</div>
"""
st.markdown("---")
# Raw HTML is required here for the inline centering style.
st.markdown(footer_html, unsafe_allow_html=True)