# DocMind RAG System — Streamlit front-end over LangChain + Groq.
# This preamble wires up imports, logging, page config, session-state
# defaults, and the sidebar model-configuration controls.

import streamlit as st
import tempfile
import os
import re
import time
from typing import List, Dict, Any, Tuple
import pickle
import hashlib
from dotenv import load_dotenv
import logging
import uuid
import base64
from io import BytesIO
import io
import fitz  # PyMuPDF for PDF image extraction
from PIL import Image

# Load environment variables from .env file
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Document Processing
import PyPDF2
import docx
import pandas as pd
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
    PyPDFLoader,
    Docx2txtLoader,
    TextLoader,
    UnstructuredExcelLoader,
)

# Vector Store
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

# Embeddings and LLM
from langchain_groq import ChatGroq
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA, LLMChain
from langchain.prompts import PromptTemplate
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains.summarize import load_summarize_chain

# For flowchart generation
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.patches import FancyArrowPatch
import numpy as np

# Configure page
st.set_page_config(
    page_title="DocMind RAG System",
    page_icon="đ",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Initialize session state.
# Each key is created only once per browser session; widgets below read and
# write these same keys so settings survive Streamlit reruns.
if "vector_store" not in st.session_state:
    st.session_state.vector_store = None
if "documents" not in st.session_state:
    st.session_state.documents = []
if "document_sources" not in st.session_state:
    st.session_state.document_sources = {}
if "processed_docs" not in st.session_state:
    st.session_state.processed_docs = []
if "selected_model" not in st.session_state:
    st.session_state.selected_model = "llama3-70b-8192"
if "chunk_size" not in st.session_state:
    st.session_state.chunk_size = 250
if "chunk_overlap" not in st.session_state:
    st.session_state.chunk_overlap = 50
if "temperature" not in st.session_state:
    st.session_state.temperature = 0.1
if "max_token_limit" not in st.session_state:
    st.session_state.max_token_limit = 8192
if "cache_dir" not in st.session_state:
    # Per-machine cache under the OS temp dir; also used by the embeddings
    # model cache and the pickled vector stores.
    st.session_state.cache_dir = os.path.join(tempfile.gettempdir(), "docmind_cache")
    os.makedirs(st.session_state.cache_dir, exist_ok=True)
if "learning_style" not in st.session_state:
    st.session_state.learning_style = "Standard"

# Get API keys from environment variables
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.sidebar.error("GROQ_API_KEY not found in environment variables. Please set it in your .env file.")
    logger.error("GROQ_API_KEY not found in environment variables")

# Create a sidebar for configuration options
st.sidebar.title("DocMind RAG System")

# Model Configuration
with st.sidebar.expander("đ¤ Model Configuration", expanded=False):
    # Mapping of model id -> display label for the selectbox
    model_options = {
        "llama3-70b-8192": "Llama-3-70b-8192",
        "llama3-8b-8192": "Llama-3-8b-8192"
    }

    st.session_state.selected_model = st.selectbox(
        "Select Model",
        options=list(model_options.keys()),
        format_func=lambda x: model_options[x],
        index=list(model_options.keys()).index(st.session_state.selected_model)
    )
    st.session_state.chunk_size = st.slider(
        "Chunk Size",
        min_value=100,
        max_value=2000,
        value=st.session_state.chunk_size,
        step=100,
        help="Size of text chunks for processing"
    )

    st.session_state.chunk_overlap = st.slider(
        "Chunk Overlap",
        min_value=0,
        max_value=500,
        value=st.session_state.chunk_overlap,
        step=50,
        help="Overlap between consecutive chunks"
    )

    st.session_state.temperature = st.slider(
        "Temperature",
        min_value=0.0,
        max_value=1.0,
        value=st.session_state.temperature,
        step=0.1,
        help="Controls randomness in responses (0=deterministic, 1=creative)"
    )
# Sidebar footer
st.sidebar.markdown("---")
st.sidebar.info("đ DocumentsMind RAG System")
st.sidebar.markdown("Powered by Groq & LangChain")


def extract_images_with_context(file_path):
    """Extract images from a PDF file along with surrounding text context.

    Returns a list of dicts with keys: data (base64 PNG/JPEG bytes), format,
    page (1-based), width, height, context (text near the image), and
    page_text (the full page text). Returns [] on failure or for PDFs with
    no usable images.
    """
    images = []
    try:
        pdf_document = fitz.open(file_path)
        for page_num in range(len(pdf_document)):
            page = pdf_document[page_num]
            # Full page text kept for broader context
            page_text = page.get_text()

            for img_info in page.get_images(full=True):
                xref = img_info[0]
                base_image = pdf_document.extract_image(xref)
                image_bytes = base_image["image"]
                image_format = base_image.get("ext", "png").lower()

                try:
                    image = Image.open(io.BytesIO(image_bytes))
                    # Skip tiny images (icons, bullets, decorations)
                    if image.width > 100 and image.height > 100:
                        # Rectangle of the image on the page
                        img_rect = page.get_image_bbox(img_info)

                        # Text within 200 points of the image rectangle
                        context_text = extract_text_around_rect(page, img_rect, radius=200)

                        buffered = io.BytesIO()
                        # FIX: PIL's format name for .jpg files is "JPEG",
                        # not "JPG" — map it so re-saving doesn't raise.
                        pil_format = {"jpg": "JPEG"}.get(image_format, image_format.upper())
                        image.save(buffered, format=pil_format)
                        img_str = base64.b64encode(buffered.getvalue()).decode()

                        images.append({
                            "data": img_str,
                            "format": image_format,
                            "page": page_num + 1,
                            "width": image.width,
                            "height": image.height,
                            "context": context_text,   # Store context with image
                            "page_text": page_text     # Full page text for broader context
                        })
                except Exception as e:
                    # One bad image should not abort the whole page
                    logger.warning(f"Error processing image: {e}")
                    continue
        logger.info(f"Extracted {len(images)} images with context from {file_path}")
        return images
    except Exception as e:
        logger.error(f"Error extracting images: {e}")
        return []


def extract_text_around_rect(page, rect, radius=200):
    """Extract text around a rectangle on a page.

    The rectangle is expanded by `radius` points on each side (clamped to
    the page bounds) and the text inside the expanded area is returned.
    """
    expanded_rect = (
        max(0, rect[0] - radius),
        max(0, rect[1] - radius),
        min(page.rect.width, rect[2] + radius),
        min(page.rect.height, rect[3] + radius)
    )
    return page.get_text("text", clip=expanded_rect)


def _cosine_similarity(vec_a, vec_b):
    """Cosine similarity between two 1-D vectors; 0.0 when either is zero.

    FIX: the original code called `cosine_similarity` (sklearn) which was
    never imported anywhere in the module, so semantic image search raised
    NameError at runtime. numpy is already imported at the top of the file.
    """
    a = np.asarray(vec_a, dtype=float)
    b = np.asarray(vec_b, dtype=float)
    denom = float(np.linalg.norm(a) * np.linalg.norm(b))
    if denom == 0.0:
        return 0.0
    return float(np.dot(a, b) / denom)


def find_relevant_images(query, document_images, embeddings):
    """Find images whose surrounding text is semantically relevant to `query`.

    `document_images` maps document name -> list of image dicts (as produced
    by extract_images_with_context). Returns up to 3 image dicts, each
    augmented with "source" and "similarity", sorted by similarity.
    """
    if not document_images:
        return []

    relevant_images = []

    # Embed the query once, then compare against each image's context
    query_embedding = embeddings.embed_query(query)

    for doc_name, images in document_images.items():
        for img in images:
            if "context" in img:
                try:
                    context_embedding = embeddings.embed_query(img["context"])
                    similarity = _cosine_similarity(query_embedding, context_embedding)

                    # Threshold chosen empirically; adjust as needed
                    if similarity > 0.3:
                        relevant_images.append({
                            **img,
                            "source": doc_name,
                            "similarity": similarity
                        })
                except Exception as e:
                    logger.warning(f"Error calculating similarity: {e}")

    # Most relevant first; cap results to avoid flooding the UI
    relevant_images.sort(key=lambda x: x.get("similarity", 0), reverse=True)
    return relevant_images[:3]


def process_document_with_images(file, progress_bar=None):
    """Process an uploaded document: load text, chunk it, and extract PDF images.

    Returns (chunks, images); on failure returns (None, []) and surfaces the
    error in the Streamlit UI. `progress_bar`, when given, is advanced to
    0.3 / 0.6 / 1.0 at the major stages.
    """
    # Save uploaded file temporarily (loaders need a real path)
    with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.name)[1]) as tmp_file:
        tmp_file.write(file.getbuffer())
        tmp_path = tmp_file.name

    try:
        loader = get_loader_for_file(tmp_path)
        docs = loader.load()

        if progress_bar:
            progress_bar.progress(0.3)

        # Extract images if PDF
        images = []
        if tmp_path.lower().endswith('.pdf'):
            images = extract_images_with_context(tmp_path)

        if progress_bar:
            progress_bar.progress(0.6)

        # Split text into chunks using the user-configured sizes
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=st.session_state.chunk_size,
            chunk_overlap=st.session_state.chunk_overlap,
            separators=["\n\n", "\n", " ", ""]
        )
        chunks = text_splitter.split_documents(docs)

        # Tag each chunk with the original upload name if the loader didn't
        for chunk in chunks:
            if "source" not in chunk.metadata:
                chunk.metadata["source"] = file.name

        if progress_bar:
            progress_bar.progress(1.0)

        return chunks, images

    except Exception as e:
        st.error(f"Error processing {file.name}: {str(e)}")
        return None, []
    finally:
        # Clean up temporary file; best-effort only
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
# Helper Functions
def get_document_hash(file_bytes: bytes) -> str:
    """Return the MD5 hex digest of a document's raw bytes.

    Used only as a cache / deduplication identifier — not for security.
    """
    return hashlib.md5(file_bytes).hexdigest()


def extract_text_from_pdf(file_path: str) -> str:
    """Extract plain text from every page of a PDF, newline-separated."""
    with open(file_path, 'rb') as file:
        pdf_reader = PyPDF2.PdfReader(file)
        text = ""
        for page in pdf_reader.pages:
            # FIX: extract_text() can return None for image-only pages,
            # which would raise TypeError on concatenation.
            text += (page.extract_text() or "") + "\n"
        return text


def extract_text_from_docx(file_path: str) -> str:
    """Extract text from a DOCX file, one line per paragraph."""
    doc = docx.Document(file_path)
    text = ""
    for paragraph in doc.paragraphs:
        text += paragraph.text + "\n"
    return text


def extract_text_from_txt(file_path: str) -> str:
    """Read a UTF-8 text file and return its full contents."""
    with open(file_path, 'r', encoding='utf-8') as file:
        return file.read()


def get_loader_for_file(file_path: str):
    """Return the appropriate LangChain document loader for a file path.

    Raises:
        ValueError: for unsupported file extensions.
    """
    file_extension = os.path.splitext(file_path)[1].lower()

    if file_extension == '.pdf':
        return PyPDFLoader(file_path)
    elif file_extension == '.docx':
        return Docx2txtLoader(file_path)
    elif file_extension == '.txt':
        return TextLoader(file_path)
    elif file_extension in ('.xlsx', '.xls'):
        # FIX: UnstructuredExcelLoader is imported at the top of the module
        # but was never wired up — Excel files now load instead of raising.
        return UnstructuredExcelLoader(file_path)
    else:
        raise ValueError(f"Unsupported file type: {file_extension}")


def process_document_with_improved_images(file, progress_bar=None):
    """Process a document, extracting text chunks and PDF images with context.

    FIX: this was a line-for-line duplicate of process_document_with_images;
    it now delegates to it, keeping the same signature and (chunks, images)
    return value for existing callers.
    """
    return process_document_with_images(file, progress_bar)


def save_vector_store(vector_store, name):
    """Pickle the vector store into the cache directory for persistence.

    NOTE(review): pickling a FAISS store is fragile across library versions;
    FAISS.save_local/load_local would be more robust — confirm before changing.
    """
    cache_path = os.path.join(st.session_state.cache_dir, f"{name}.pkl")
    with open(cache_path, "wb") as f:
        pickle.dump(vector_store, f)
def load_vector_store(name):
    """Load a pickled vector store from the cache directory.

    Returns the unpickled store, or None when no cache file exists.
    NOTE(review): pickle.load is only safe here because the cache is written
    locally by this app — never point cache_dir at untrusted data.
    """
    cache_path = os.path.join(st.session_state.cache_dir, f"{name}.pkl")
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    return None


def initialize_llm():
    """Initialize the Groq chat model using the current sidebar settings.

    Returns a streaming ChatGroq instance, or None when the API key is
    missing or construction fails (an error is shown in the UI either way).
    """
    if not GROQ_API_KEY:
        st.error("GROQ_API_KEY not found in environment variables. Please set it in your .env file.")
        return None

    try:
        return ChatGroq(
            groq_api_key=GROQ_API_KEY,
            model_name=st.session_state.selected_model,
            temperature=st.session_state.temperature,
            max_tokens=st.session_state.max_token_limit,
            streaming=True,
            # Streams tokens to stdout; useful for server-side debugging
            callbacks=[StreamingStdOutCallbackHandler()]
        )
    except Exception as e:
        st.error(f"Error initializing language model: {str(e)}")
        logger.error(f"Error initializing language model: {str(e)}")
        return None


def initialize_embeddings():
    """Initialize the local sentence-transformer embedding model.

    Groq does not provide an embeddings API, so a HuggingFace model is used;
    its weights are cached under the app cache directory. Returns None on
    failure (error surfaced in the UI and log).
    """
    try:
        # Using HuggingFace embeddings as Groq doesn't provide embedding API
        return HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            cache_folder=os.path.join(st.session_state.cache_dir, "hf_models")
        )
    except Exception as e:
        st.error(f"Error initializing embeddings: {str(e)}")
        logger.error(f"Error initializing embeddings: {str(e)}")
        return None


# Improved get_learning_style_prompt function to ensure visual elements for visual learners
def get_learning_style_prompt(learning_style):
    """Return extra prompt instructions tailored to a learning style.

    Unknown styles (including "Standard") return an empty string, so the
    base prompt is used unchanged.
    """
    styles = {
        "Visual learner": """
        FORMAT YOUR ANSWER FOR A VISUAL LEARNER:
        - Always include a Markdown header: "## Visual Learner Format"
        - Use descriptive language that creates mental images
        - Structure information spatially (use tables, diagrams when possible)
        - Suggest visual metaphors for complex concepts
        - Include a Mermaid.js diagram to visualize the key concepts
        - Always include at least one table to organize the information
        - Create a flowchart for process-related information

        Example visual elements you MUST include:

        ```mermaid
        graph TD
            A[Concept] --> B[Related Concept]
            B --> C[Example]
            A --> D[Another Related Concept]
        ```

        | Concept | Description | Example |
        |---------|-------------|---------|
        | Key Idea 1 | Description 1 | Example 1 |
        | Key Idea 2 | Description 2 | Example 2 |

        DO NOT omit the table and diagram - they are essential for visual learning.
        """,

        "Auditory learner": """
        FORMAT YOUR ANSWER FOR AN AUDITORY LEARNER:
        - Always include a Markdown header: "## Auditory Learner Format"
        - Use conversational language and rhetorical questions
        - Include examples of how you'd explain this verbally
        - Suggest mnemonics or rhymes to remember key points
        - Use a dialogue format where appropriate (Q&A)
        - Repeat key points in different ways
        """,

        "Reading/writing learner": """
        Format your answer for a reading/writing learner:
        1. Use precise and detailed textual explanations
        2. Include well-structured paragraphs with topic sentences
        3. Provide detailed lists and definitions
        4. Use quotes and references when appropriate
        5. Include clear headings, subheadings, and a logical progression of ideas
        6. Emphasize key points through repetition in written form
        """,

        "Kinesthetic learner": """
        Format your answer for a kinesthetic learner:
        1. Focus on practical applications and real-world examples
        2. Include step-by-step processes and procedures
        3. Relate concepts to physical actions or sensations
        4. Use case studies and scenarios that involve doing or experiencing
        5. Include interactive elements or suggestions for hands-on activities
        6. Break information into actionable chunks with clear sequences
        """
    }

    return styles.get(learning_style, "")
# Enhanced function to perform QA with visual elements for visual learners
def perform_qa_with_images(question, vector_store, k=4, learning_style="Standard"):
    """Answer `question` via RAG over `vector_store`, adapted to a learning style.

    Retrieves the top-k chunks, runs a "stuff" RetrievalQA chain, performs a
    semantic image search over previously extracted PDF images, and (for
    visual learners) generates a flowchart.

    Returns a dict with keys: answer, sources, relevant_images, flowchart —
    or None when the LLM / embeddings cannot be initialized.
    """
    llm = initialize_llm()
    if not llm:
        return None

    # Embeddings are needed for the semantic image search below
    embeddings = initialize_embeddings()
    if not embeddings:
        return None

    retriever = vector_store.as_retriever(search_kwargs={"k": k})

    learning_style_instructions = get_learning_style_prompt(learning_style)

    # Hard requirements so visual learners always get tables / spatial layout
    visual_instructions = ""
    if learning_style == "Visual learner":
        visual_instructions = """
        You MUST include these visual elements in your response:
        1. A table summarizing key points
        2. Use of spatial organization for information

        Do not mention that you're including these elements because of instructions - just include them naturally.
        """

    # Enhanced RAG prompt; doubled braces survive the f-string so that
    # {context} and {question} remain LangChain template variables.
    qa_template = f"""
    You are a helpful assistant that answers questions based on provided context.

    Context:
    {{context}}

    Question: {{question}}

    Instructions:
    1. Answer the question based ONLY on the provided context.
    2. If you don't know the answer based on the context, say "I don't have enough information to answer this question."
    3. Provide specific references to the source documents when possible.
    4. Be concise but thorough in your response.
    5. Format your answer in a clear, readable way with appropriate headings, bullet points, and structure.
    6. If the question is about a general topic not specifically covered in the documents, synthesize information from the context to provide a helpful response.
    7. If the question requires reasoning beyond what's in the documents, use the provided context as a foundation and clearly indicate when you're making inferences.

    {learning_style_instructions}
    {visual_instructions}

    Answer:
    """

    QA_PROMPT = PromptTemplate(
        template=qa_template,
        input_variables=["context", "question"]
    )

    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs={"prompt": QA_PROMPT},
        return_source_documents=True
    )

    result = qa({"query": question})

    # Collect distinct source names from the retrieved documents.
    # FIX: was a bare `except: pass`, which silently hid any real bug in
    # metadata handling — narrowed to Exception and logged.
    sources = []
    source_documents = []
    try:
        for doc in result.get("source_documents", []):
            source_documents.append(doc)
            if "source" in doc.metadata:
                if doc.metadata["source"] not in sources:
                    sources.append(doc.metadata["source"])
    except Exception as e:
        logger.warning(f"Error extracting source metadata: {e}")

    # Find semantically relevant images
    relevant_images = []
    if "document_images" in st.session_state:
        # Images related to the query semantically
        relevant_images.extend(find_relevant_images(
            question,
            st.session_state.document_images,
            embeddings
        ))

        # Also include up to 2 images from each retrieved source document
        for source in sources:
            if source in st.session_state.document_images:
                source_images = st.session_state.document_images[source]
                if source_images:
                    relevant_images.extend(source_images[:2])

    # Always initialize flowchart_result so the return dict is well-formed
    flowchart_result = None
    if learning_style == "Visual learner":
        try:
            flowchart_result = generate_hierarchical_flowchart(vector_store, question)
        except Exception as e:
            logger.error(f"Error generating flowchart: {str(e)}")
            flowchart_result = None

    # Deduplicate images, using the first 50 chars of the base64 payload
    # as a cheap identity key
    unique_images = []
    unique_data_strings = set()
    for img in relevant_images:
        if img["data"][:50] not in unique_data_strings:
            unique_data_strings.add(img["data"][:50])
            unique_images.append(img)

    return {
        "answer": result.get("result", "No answer generated"),
        "sources": sources,
        "relevant_images": unique_images,
        "flowchart": flowchart_result
    }
# NOTE: a second, byte-identical definition of extract_text_from_pdf used to
# live here; it silently re-shadowed the one defined earlier in the file and
# has been removed (zero behavior change).


# Enhanced generate_summary function for visual learners
def generate_summary(vector_store, learning_style: str = "Standard"):
    """Generate a comprehensive summary of the indexed documents.

    Uses a map_reduce summarize chain so large inputs stay under the token
    limit. Output formatting adapts to `learning_style`; for visual learners
    a flowchart and a keyword table are additionally produced.

    Returns a dict with keys: summary, sources, flowchart, table_data —
    or None when the LLM cannot be initialized.
    """
    llm = initialize_llm()
    if not llm:
        return None

    # Sample chunks for the summary. FIX: the old follow-up
    # `if len(docs) > 10: docs = docs[:10]` was dead code — k=10 already
    # caps the result count.
    docs = vector_store.similarity_search("", k=10)

    learning_style_instructions = get_learning_style_prompt(learning_style)

    # Hard requirements so visual learners always get tables and a diagram
    visual_instructions = ""
    if learning_style == "Visual learner":
        visual_instructions = """
        You MUST include these visual elements in your response:
        1. A table summarizing key points from the documents
        2. A mermaid flowchart showing the relationship between main concepts
        3. Use of spatial organization for information

        Do not mention that you're including these elements because of instructions - just include them naturally.
        """

    # map step: summarize each chunk independently
    map_prompt_template = """
    Write a concise summary of the following text:
    "{text}"
    CONCISE SUMMARY:
    """
    MAP_PROMPT = PromptTemplate(template=map_prompt_template, input_variables=["text"])

    # combine step: fold the per-chunk summaries into one overview
    combine_prompt_template = f"""
    You are a helpful assistant that combines multiple document summaries into a comprehensive overview.

    Below are summaries of different sections of a document or multiple documents:
    {{text}}

    Create a final comprehensive summary that captures the key information from all these summaries.

    {learning_style_instructions}
    {visual_instructions}

    The summary should be well-organized, coherent, and highlight the most important points.
    Use appropriate headings, bullet points, and structure to make the summary easy to read and understand.

    COMPREHENSIVE SUMMARY:
    """
    COMBINE_PROMPT = PromptTemplate(template=combine_prompt_template, input_variables=["text"])

    summary_chain = load_summarize_chain(
        llm,
        chain_type="map_reduce",
        map_prompt=MAP_PROMPT,
        combine_prompt=COMBINE_PROMPT,
        verbose=True
    )

    summary = summary_chain.run(docs)

    # Collect distinct source names
    sources = []
    for doc in docs:
        if "source" in doc.metadata:
            if doc.metadata["source"] not in sources:
                sources.append(doc.metadata["source"])

    # Flowchart for visual learners (None for other styles / on failure)
    flowchart_result = None
    if learning_style == "Visual learner":
        try:
            flowchart_result = generate_hierarchical_flowchart(vector_store, "Create a concept map of the main ideas in these documents")
        except Exception as e:
            logger.error(f"Error generating flowchart: {str(e)}")
            flowchart_result = None

    # Keyword table for visual learners (None for other styles / on failure)
    table_data = None
    if learning_style == "Visual learner":
        try:
            keywords = extract_keywords(vector_store, num_keywords=8)
            table_data = {
                "keywords": keywords,
                "explanation": "Key topics extracted from the documents"
            }
        except Exception as e:
            logger.error(f"Error extracting keywords: {str(e)}")
            table_data = None

    return {
        "summary": summary,
        "sources": sources,
        "flowchart": flowchart_result,
        "table_data": table_data
    }
"keywords": keywords, - "explanation": "Key topics extracted from the documents" - } - except Exception as e: - logger.error(f"Error extracting keywords: {str(e)}") - table_data = None - - return { - "summary": summary, - "sources": sources, - "flowchart": flowchart_result, - "table_data": table_data - } - -# Add this function to display the analysis results properly -def display_analysis_results(result, analysis_type, learning_style): - """Display analysis results with appropriate visual elements based on learning style""" - st.markdown(f"### {analysis_type} Results") - st.markdown(result["result"]) - - # For visual learners, always show the flowchart - if learning_style == "Visual learner" and "flowchart" in result and result["flowchart"]: - st.markdown("### Visual Representation") - st.image(f"data:image/png;base64,{result['flowchart']['visual']}", - caption=result["flowchart"].get("title", f"Visual representation of {analysis_type}"), - use_column_width=True) - -def generate_flowchart(vector_store): - """Generate a flowchart based on document content or extract existing diagrams""" - llm = initialize_llm() - if not llm: - return None - - # First check if any documents contain diagrams/images - docs = vector_store.similarity_search("diagram OR image OR figure OR chart", k=5) - - # Check if any documents mention diagrams - has_diagrams = any( - "diagram" in doc.page_content.lower() or - "image" in doc.page_content.lower() or - "figure" in doc.page_content.lower() or - "chart" in doc.page_content.lower() - for doc in docs - ) - - if has_diagrams: - # Extract information about existing diagrams - diagram_prompt = PromptTemplate( - input_variables=["text"], - template=""" - Analyze the following text and identify any diagrams, images, or figures that are referenced. - For each visual element found, provide: - 1. Its title or description - 2. The page/section where it appears - 3. 
A brief explanation of what it shows - - If no diagrams are found, say "No diagrams found in the documents." - - Text: - {text} - - Diagram Analysis: - """ - ) - - chain = LLMChain(llm=llm, prompt=diagram_prompt) - result = chain.run(text="\n\n".join([doc.page_content for doc in docs])) - - return { - "type": "description", - "content": result, - "visual": None - } - else: - # No diagrams found, create a new flowchart - combined_text = "\n\n".join([doc.page_content for doc in docs]) - - flowchart_prompt = PromptTemplate( - input_variables=["text"], - template=""" - Analyze the following text and identify key concepts, processes, or entities that could be organized into a flowchart. - - Text: - {text} - - Please extract: - 1. A list of 5-10 key nodes (main concepts) - 2. A descriptive title for the flowchart - - Format your response as JSON: - {{ - "title": "The flowchart title", - "nodes": ["Node1", "Node2", "Node3", ...], - "connections": [ - ["Node1", "Node2", "description of connection"], - ["Node2", "Node3", "description of connection"], - ... - ] - }} - - Only provide the JSON with no other text or explanation. 
- """ - ) - - chain = LLMChain(llm=llm, prompt=flowchart_prompt) - result = chain.run(text=combined_text) - - try: - import json - flowchart_data = json.loads(result) - img_str = create_flowchart_image(flowchart_data) - - return { - "type": "flowchart", - "content": flowchart_data, - "visual": img_str - } - except Exception as e: - st.error(f"Error creating flowchart: {str(e)}") - return None - -def create_flowchart_image(flowchart_data): - """Create a flowchart image from structured data""" - try: - # Create a directed graph - G = nx.DiGraph() - - # Add nodes - for node in flowchart_data["nodes"]: - G.add_node(node) - - # Add edges - for connection in flowchart_data["connections"]: - from_node, to_node, label = connection - G.add_edge(from_node, to_node, label=label) - - # Create figure - plt.figure(figsize=(12, 8)) - plt.title(flowchart_data["title"], fontsize=16) - - # Set up the layout - pos = nx.spring_layout(G, seed=42) - - # Draw nodes - nx.draw_networkx_nodes(G, pos, node_size=2000, node_color="lightblue", alpha=0.8) - - # Draw edges - curved_edges = [edge for edge in G.edges()] - nx.draw_networkx_edges( - G, pos, edgelist=curved_edges, width=2, alpha=0.5, - connectionstyle="arc3,rad=0.1", arrowsize=20 - ) - - # Draw node labels - nx.draw_networkx_labels(G, pos, font_size=12, font_family="sans-serif") - - # Draw edge labels - edge_labels = nx.get_edge_attributes(G, "label") - nx.draw_networkx_edge_labels( - G, pos, edge_labels=edge_labels, font_size=10, - bbox=dict(facecolor="white", edgecolor="none", alpha=0.7), - rotate=False, label_pos=0.5 - ) - - plt.axis("off") - plt.tight_layout() - - # Save to BytesIO - buf = BytesIO() - plt.savefig(buf, format="png", bbox_inches="tight") - plt.close() - buf.seek(0) - - # Convert to base64 - img_str = base64.b64encode(buf.read()).decode() - - return img_str - except Exception as e: - st.error(f"Error generating flowchart: {str(e)}") - return None - - -import re -def extract_keywords(vector_store, num_keywords=10): - 
"""Extract key topics/keywords from the documents""" - llm = initialize_llm() - if not llm: - return None - - # Get a sample of documents from vector store - docs = vector_store.similarity_search("", k=5) - - # Combine document contents - combined_text = "\n\n".join([doc.page_content for doc in docs]) - - # Create prompt for keyword extraction with clear formatting instructions - keyword_prompt = PromptTemplate( - input_variables=["text", "num_keywords"], - template=""" - Extract exactly {num_keywords} important keywords or key phrases from the following text. - Return ONLY the keywords, one per line, with no numbering, explanations, or additional text. - - TEXT: {text} - KEYWORDS: - """ - ) - - chain = LLMChain(llm=llm, prompt=keyword_prompt) - result = chain.run(text=combined_text, num_keywords=num_keywords) - - # Clean and format the keywords - cleaned_result = result.strip() - - # Remove common prefixes and explanatory text - unwanted_patterns = [ - "Here are the", "most important", "keywords or key phrases", "extracted from the text:", - "separated by commas:", "The", "keywords are:", "Here's the list of", "1.", "2.", "3.", - "âĸ", "*", "-", "KEYWORDS:", "Key topics:" - ] - - for pattern in unwanted_patterns: - cleaned_result = cleaned_result.replace(pattern, "") - - # Split by common delimiters (newlines, commas) and clean each item - if '\n' in cleaned_result: - keywords = [k.strip() for k in cleaned_result.split('\n') if k.strip()] - else: - keywords = [k.strip() for k in cleaned_result.split(',') if k.strip()] - - # Further clean any numbering or bullets that might remain - keywords = [re.sub(r'^\d+\.\s*', '', k).strip() for k in keywords] - - # Ensure each keyword is 1-3 words maximum - filtered_keywords = [] - for k in keywords: - words = k.split() - if len(words) <= 10: - filtered_keywords.append(k) - else: - filtered_keywords.append(' '.join(words[:3])) - - # Remove duplicates while preserving order - seen = set() - unique_keywords = [k for k in 
def analyze_sentiment(vector_store):
    """Analyze the overall sentiment of the indexed documents.

    Samples 5 chunks from the vector store, asks the LLM to classify the
    combined text on a five-point scale (Strongly Positive .. Strongly
    Negative) with a brief justification, and returns the raw LLM text.
    Returns None when the LLM cannot be initialized.
    """
    llm = initialize_llm()
    if not llm:
        return None

    # Get a sample of documents from vector store
    docs = vector_store.similarity_search("", k=5)

    # Combine document contents into a single prompt payload
    combined_text = "\n\n".join([doc.page_content for doc in docs])

    # Create prompt for sentiment analysis
    sentiment_prompt = PromptTemplate(
        input_variables=["text"],
        template="""
        Analyze the sentiment of the following text. Determine if it is:
        1. Strongly Positive
        2. Positive
        3. Neutral
        4. Negative
        5. Strongly Negative

        Also provide a brief explanation of your assessment.

        TEXT:
        {text}

        SENTIMENT ANALYSIS:
        """
    )

    chain = LLMChain(llm=llm, prompt=sentiment_prompt)
    result = chain.run(text=combined_text)

    return result
- - User question: {question} - - Text: - {text} - - Create a hierarchical flowchart with 5-10 nodes that effectively visualizes the main concepts related to the question. - - Format your response as JSON: - {{ - "title": "A clear title for the flowchart", - "nodes": ["Main Concept", "Sub-concept 1", "Sub-concept 2", ...], - "connections": [ - ["Main Concept", "Sub-concept 1", "relationship"], - ["Main Concept", "Sub-concept 2", "relationship"], - ["Sub-concept 1", "Sub-concept 1.1", "relationship"], - ... - ], - "layout": "hierarchical" - }} - - Only provide the JSON with no other text or explanation. - """ - ) - - chain = LLMChain(llm=llm, prompt=flowchart_prompt) - result = chain.run(text=combined_text, question=question) - - try: - import json - flowchart_data = json.loads(result) - - # Create a directed graph - G = nx.DiGraph() - - # Add nodes - for node in flowchart_data["nodes"]: - G.add_node(node) - - # Add edges - for connection in flowchart_data["connections"]: - from_node, to_node, label = connection - G.add_edge(from_node, to_node, label=label) - - # Create figure - plt.figure(figsize=(12, 8)) - plt.title(flowchart_data["title"], fontsize=16) - - # Use hierarchical layout - pos = nx.nx_pydot.graphviz_layout(G, prog="dot") - - # Draw nodes - nx.draw_networkx_nodes(G, pos, node_size=2000, node_color="lightblue", alpha=0.8) - - # Draw edges - nx.draw_networkx_edges(G, pos, width=2, alpha=0.5, arrowsize=20) - - # Draw node labels - nx.draw_networkx_labels(G, pos, font_size=10, font_family="sans-serif") - - # Draw edge labels - edge_labels = nx.get_edge_attributes(G, "label") - nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8) - - plt.axis("off") - plt.tight_layout() - - # Save to BytesIO - buf = BytesIO() - plt.savefig(buf, format="png", dpi=300, bbox_inches="tight") - plt.close() - buf.seek(0) - - # Convert to base64 - img_str = base64.b64encode(buf.read()).decode() - - return {"visual": img_str, "title": flowchart_data["title"]} - 
except Exception as e: - print(f"Error generating flowchart: {str(e)}") - return None -# Main Navigation Menu using native Streamlit components -tabs = st.tabs(["đ Upload", "đ Q&A", "đ Summarize", "đ Analyze"]) - -# Upload Documents Page -with tabs[0]: - st.title("đ Upload Documents") - st.markdown("Upload your documents for processing and analysis. Supported formats: PDF, DOCX, TXT") - - uploaded_files = st.file_uploader( - "Choose files", - type=["pdf", "docx", "txt"], - accept_multiple_files=True - ) - - col1, col2 = st.columns([1, 1]) - - with col1: - st.info(f"Current Settings:\n- Chunk Size: {st.session_state.chunk_size}\n- Chunk Overlap: {st.session_state.chunk_overlap}") - - with col2: - clear_button = st.button("Clear All Documents") - if clear_button: - st.session_state.documents = [] - st.session_state.vector_store = None - st.session_state.document_sources = {} - st.session_state.processed_docs = [] - st.success("All documents cleared!") - - if uploaded_files: - process_button = st.button("Process Documents") - - if process_button: - with st.spinner("Processing documents..."): - # Initialize embeddings - embeddings = initialize_embeddings() - if not embeddings: - st.error("Failed to initialize embeddings. 
Please check your API key.") - else: - # Process each document - new_docs = [] - - # Initialize image storage if not present - if "document_images" not in st.session_state: - st.session_state.document_images = {} - - # Set up progress tracking - progress_text = st.empty() - progress_bar = st.progress(0) - - for i, file in enumerate(uploaded_files): - # Check if we've already processed this file - file_hash = get_document_hash(file.getvalue()) - - if file_hash in st.session_state.processed_docs: - progress_text.text(f"Skipping already processed file: {file.name}") - continue - - progress_text.text(f"Processing {file.name} ({i+1}/{len(uploaded_files)})") - doc_progress = st.progress(0) - - chunks, images = process_document_with_images(file, progress_bar=doc_progress) - - if chunks: - new_docs.extend(chunks) - st.session_state.document_sources[file.name] = len(chunks) - st.session_state.processed_docs.append(file_hash) - - # Store images if any were found - if images: - st.session_state.document_images[file.name] = images - - progress_bar.progress((i + 1) / len(uploaded_files)) - - if new_docs: - # Generate unique IDs for each document - texts = [doc.page_content for doc in new_docs] - metadatas = [doc.metadata for doc in new_docs] - - # Create new vector store if none exists - if st.session_state.vector_store is None: - # Generate unique IDs for new documents - ids = [str(uuid.uuid4()) for _ in range(len(new_docs))] - st.session_state.vector_store = FAISS.from_texts(texts, embeddings, metadatas=metadatas, ids=ids) - else: - # Add new documents to existing vector store with unique IDs - ids = [str(uuid.uuid4()) for _ in range(len(new_docs))] - st.session_state.vector_store.add_texts(texts, metadatas=metadatas, ids=ids) - - progress_text.text("Creating vector index...") - - # Save vector store - save_vector_store(st.session_state.vector_store, "docmind_store") - - st.session_state.documents.extend(new_docs) - st.success(f"Successfully processed {len(new_docs)} new 
document chunks!") - else: - st.warning("No new documents to process.") - - # Clear progress indicators - progress_text.empty() - progress_bar.empty() - - - # Display status of processed documents - if st.session_state.document_sources: - st.subheader("Processed Documents") - - # Create summary table - document_data = [] - for doc_name, chunk_count in st.session_state.document_sources.items(): - document_data.append({"Document": doc_name, "Chunks": chunk_count}) - - # Display as DataFrame - st.dataframe(pd.DataFrame(document_data)) - - total_chunks = sum(st.session_state.document_sources.values()) - st.info(f"Total document chunks in memory: {total_chunks}") - - else: - st.write("No documents processed yet. Please upload and process some documents.") - - -# Then in your chat processing, utilize the memory -def process_chat_with_rag(user_input, vector_store, memory): - llm = initialize_llm() - retriever = vector_store.as_retriever(search_kwargs={"k": 4}) - - qa = ConversationalRetrievalChain.from_llm( - llm=llm, - retriever=retriever, - memory=memory - ) - - result = qa({"question": user_input}) - return result -# Q&A Page -with tabs[1]: - st.title("đ Question & Answer") - - if not st.session_state.vector_store: - # Try to load from cache - st.session_state.vector_store = load_vector_store("docmind_store") - - if not st.session_state.vector_store: - st.warning("No documents have been processed. 
Please go to the Upload tab first.") - else: - st.markdown("Ask a question about your documents and get answers based on their content.") - - # Learning style selector - learning_styles = ["Standard", "Visual learner", "Auditory learner", "Reading/writing learner", "Kinesthetic learner"] - selected_style = st.radio("Select your learning style for the answer:", learning_styles, horizontal=True) - # Replace with enhanced input that encourages natural language questions: - st.markdown("Ask me anything about your documents.") - question = st.text_area("Enter your question:", height=100, placeholder="Ask me anything about your documents. I can answer based on the content while providing additional context.") - - col1, col2, col3 = st.columns([1, 1, 1]) - - with col1: - k_value = st.slider( - "Number of relevant chunks to retrieve", - min_value=1, - max_value=10, - value=4, - step=1, - help="Higher values may provide more comprehensive answers but can introduce noise" - ) - - with col2: - submit_button = st.button("Submit Question") - - with col3: - generate_visual = st.checkbox("Generate visual flowchart", value=True) - - if question and submit_button: - with st.spinner("Generating answer..."): - # Perform Q&A with semantic image search - result = perform_qa_with_images(question, st.session_state.vector_store, k=k_value, learning_style=selected_style) - - if result: - # Display answer with better formatting - st.markdown("### Answer") - st.markdown(result["answer"]) - - # Display sources if available - if result["sources"]: - st.markdown("### Sources") - for source in result["sources"]: - st.markdown(f"- {source}") - - # Display semantically relevant images - if "relevant_images" in result and result["relevant_images"]: - st.markdown("### Relevant Images") - - # Create columns for images - img_cols = st.columns(min(len(result["relevant_images"]), 2)) - - for i, img in enumerate(result["relevant_images"]): - col_idx = i % len(img_cols) - with img_cols[col_idx]: - source 
= img.get("source", "Unknown source") - similarity = img.get("similarity", None) - - caption = f"{source} (Page {img['page']})" - if similarity: - caption += f" - Relevance: {similarity:.2f}" - - st.image( - f"data:image/{img['format']};base64,{img['data']}", - caption=caption, - use_column_width=True - ) - -# Use this updated code in the Summary tab display section -def display_summary_results(summary_result, learning_style): - """Display summary results with appropriate visual elements based on learning style""" - # Display summary with better formatting - st.markdown("### Document Summary") - st.markdown(summary_result["summary"]) - - # Display sources if available - if summary_result["sources"]: - st.markdown("### Sources") - for source in summary_result["sources"]: - st.markdown(f"- {source}") - - if learning_style == "Visual learner": - # Check if table data is available - if "table_data" in summary_result and summary_result["table_data"]: - print("### Key Topics Table") - # Code to display the table goes here - else: - print("No table found for this document. Please refer to other resources for visual aids.") - - - # Display flowchart if available - if "flowchart" in summary_result and summary_result["flowchart"]: - st.markdown("### Concept Relationship Map") - st.image(f"data:image/png;base64,{summary_result['flowchart']['visual']}", - caption=summary_result["flowchart"].get("title", "Document Concept Map"), - use_column_width=True) - -# Summarize Page -with tabs[2]: - st.title("đ Document Summarization") - - if not st.session_state.vector_store: - # Try to load from cache - st.session_state.vector_store = load_vector_store("docmind_store") - - if not st.session_state.vector_store: - st.warning("No documents have been processed. 
Please go to the Upload tab first.") - else: - st.markdown("Generate a comprehensive summary of all your uploaded documents.") - - # Learning style selector - summary_learning_styles = ["Standard", "Visual learner", "Auditory learner", "Reading/writing learner", "Kinesthetic learner"] - selected_summary_style = st.radio("Select your learning style for the summary:", summary_learning_styles, horizontal=True) - - summary_col1, summary_col2 = st.columns([1, 1]) - - with summary_col1: - summary_button = st.button("Generate Summary") - - with summary_col2: - summary_visual = st.checkbox("Include visual representation", value=True) - - if summary_button: - with st.spinner("Generating comprehensive summary..."): - # Generate summary - summary_result = generate_summary(st.session_state.vector_store, learning_style=selected_summary_style) - - if summary_result: - # Use the new display function - display_summary_results(summary_result, selected_summary_style) - else: - st.error("Failed to generate summary. Please try again.") -# Analysis Page -with tabs[3]: - st.title("đ Document Analysis") - - if not st.session_state.vector_store: - # Try to load from cache - st.session_state.vector_store = load_vector_store("docmind_store") - - if not st.session_state.vector_store: - st.warning("No documents have been processed. 
Please go to the Upload tab first.") - else: - st.markdown("Analyze your documents to extract key information and insights.") - - # Tabs for different analysis types - analysis_tabs = st.tabs(["đ Key Topics", "đ Sentiment Analysis","đ Custom Analysis"]) - - # Key Topics tab - with analysis_tabs[0]: - st.subheader("Key Topics & Keywords") - - num_keywords = st.slider("Number of keywords to extract", 5, 20, 10) - - extract_topics_button = st.button("Extract Key Topics") - - if extract_topics_button: - with st.spinner("Extracting key topics and keywords..."): - keywords = extract_keywords(st.session_state.vector_store, num_keywords=num_keywords) - - if keywords: - st.markdown("### Key Topics") - - keyword_cols = st.columns(3) - for i, keyword in enumerate(keywords): - col_idx = i % 3 - with keyword_cols[col_idx]: - st.markdown(f"- **{keyword}**") - - with st.spinner("Generating topic visualization..."): - import matplotlib.pyplot as plt - - keywords_to_plot = keywords[:min(10, len(keywords))] - - fig, ax = plt.subplots(figsize=(10, 6)) - - y_pos = np.arange(len(keywords_to_plot)) - importance = np.linspace(1, 0.4, len(keywords_to_plot)) - - ax.barh(y_pos, importance, align='center', color='skyblue') - ax.set_yticks(y_pos) - ax.set_yticklabels(keywords_to_plot) - ax.invert_yaxis() - ax.set_xlabel('Relative Importance') - ax.set_title('Key Topics by Relative Importance') - - buf = BytesIO() - plt.savefig(buf, format="png", bbox_inches="tight") - plt.close() - buf.seek(0) - - img_str = base64.b64encode(buf.read()).decode() - - st.image(f"data:image/png;base64,{img_str}", - caption="Key Topics Visualization", - use_column_width=True) - else: - st.error("Failed to extract keywords. 
Please try again.") - - # Sentiment Analysis tab - with analysis_tabs[1]: - st.subheader("Sentiment Analysis") - - sentiment_button = st.button("Analyze Sentiment") - - if sentiment_button: - with st.spinner("Analyzing document sentiment..."): - sentiment_result = analyze_sentiment(st.session_state.vector_store) - - if sentiment_result: - st.markdown("### Sentiment Analysis Results") - st.markdown(sentiment_result) - - # Try to extract the sentiment category for visualization - sentiment_categories = ["Strongly Positive", "Positive", "Neutral", "Negative", "Strongly Negative"] - detected_sentiment = None - - for category in sentiment_categories: - if category.lower() in sentiment_result.lower(): - detected_sentiment = category - break - - if detected_sentiment: - # Create a visual representation of sentiment - fig, ax = plt.subplots(figsize=(10, 4)) - - categories_pos = np.arange(len(sentiment_categories)) - sentiment_pos = sentiment_categories.index(detected_sentiment) - - # Create bars - bars = ax.bar(categories_pos, - [0.2 if i != sentiment_pos else 0.8 for i in range(len(sentiment_categories))], - color=['lightgray' if i != sentiment_pos else 'skyblue' for i in range(len(sentiment_categories))]) - - # Highlight the detected sentiment - bars[sentiment_pos].set_color('blue') - - ax.set_xticks(categories_pos) - ax.set_xticklabels(sentiment_categories) - ax.set_ylabel('Confidence') - ax.set_title('Document Sentiment Analysis') - - # Save to BytesIO - buf = BytesIO() - plt.savefig(buf, format="png", bbox_inches="tight") - plt.close() - buf.seek(0) - - # Convert to base64 - img_str = base64.b64encode(buf.read()).decode() - - # Display the chart - st.image(f"data:image/png;base64,{img_str}", - caption="Sentiment Visualization", - use_column_width=True) - else: - st.error("Failed to analyze sentiment. 
Please try again.") - # Custom Analysis tab -with analysis_tabs[2]: - st.subheader("Custom Document Analysis") - - custom_analysis_types = [ - "Content Structure Analysis", - "Main Arguments Extraction", - "Learning Objectives Identification", - "Technical Complexity Assessment", - "Key Definitions Extraction", - "Action Items Identification" - ] - - selected_analysis = st.selectbox("Select analysis type:", custom_analysis_types) - - custom_learning_styles = ["Standard", "Visual learner", "Auditory learner", "Reading/writing learner", "Kinesthetic learner"] - custom_selected_style = st.radio("Select your learning style for the analysis:", custom_learning_styles, horizontal=True) - - custom_analysis_button = st.button("Run Custom Analysis") - - if custom_analysis_button: - with st.spinner(f"Running {selected_analysis}..."): - analysis_result = perform_custom_analysis(selected_analysis, st.session_state.vector_store, learning_style=custom_selected_style) - if analysis_result: - display_analysis_results(analysis_result, selected_analysis, custom_selected_style) - else: - st.error(f"Failed to perform {selected_analysis}. Please try again.") - llm = initialize_llm() - if not llm: - st.error("Failed to initialize language model. Please check your API key.") - else: - # Get documents from vector store - docs = st.session_state.vector_store.similarity_search("", k=5) - - # Combine document contents - combined_text = "\n\n".join([doc.page_content for doc in docs]) - - learning_style_instructions = get_learning_style_prompt(custom_selected_style) - - # Create prompt based on selected analysis type - analysis_prompts = { - "Content Structure Analysis": f""" - Analyze the structure of the following content. Identify main sections, subsections, and how - information is organized. Evaluate the logical flow of information and make recommendations - for improved structure if applicable. 
- - CONTENT: - {{text}} - """, - - "Main Arguments Extraction": f""" - Extract the main arguments or key points from the following content. Identify the central thesis, - supporting arguments, evidence provided, and any counterarguments addressed. - - {learning_style_instructions} - - CONTENT: - {{text}} - """, - - "Learning Objectives Identification": f""" - Analyze the following content and identify what appear to be the main learning objectives. - What key knowledge, skills, or competencies would someone gain from this material? - - {learning_style_instructions} - - CONTENT: - {{text}} - """, - - "Technical Complexity Assessment": f""" - Assess the technical complexity of the following content. Identify specialized terminology, - complex concepts, prerequisites needed to understand this material, and the overall complexity level. - - {learning_style_instructions} - - CONTENT: - {{text}} - """, - - "Key Definitions Extraction": f""" - Extract key terms and their definitions from the following content. Identify important concepts, - technical terms, and provide clear definitions based on the context they're used in. - - {learning_style_instructions} - - CONTENT: - {{text}} - """, - - "Action Items Identification": f""" - Analyze the following content and extract any action items, tasks, or next steps mentioned. - Identify who should complete them (if specified), deadlines or timeframes, and priority levels. 
- - {learning_style_instructions} - - CONTENT: - {{text}} - """ - } - - selected_prompt = PromptTemplate( - input_variables=["text"], - template=analysis_prompts[selected_analysis] - ) - - chain = LLMChain(llm=llm, prompt=selected_prompt) - result = chain.run(text=combined_text) - - # Display results - st.markdown(f"### {selected_analysis} Results") - st.markdown(result) - -# CSS Customizations -st.markdown(""" - -""", unsafe_allow_html=True) - -# Add explanation for learning styles -with st.sidebar.expander("âšī¸ Learning Styles Explained", expanded=False): - st.markdown(""" - **Visual Learner**: - Prefers information presented with images, charts, and spatial relationships. - - **Auditory Learner**: - Learns best through listening, discussions, and verbal explanations. - - **Reading/Writing Learner**: - Prefers information displayed as words, lists, and text-based materials. - - **Kinesthetic Learner**: - Learns through doing, practical examples, and hands-on experiences. - """) - -# Add additional learning resources section -with st.sidebar.expander("đ Learning Resources", expanded=False): - st.markdown(""" - **For Visual Learners:** - - Use mind maps to connect ideas - - Convert text to diagrams - - Watch video explanations - - **For Auditory Learners:** - - Read content aloud - - Discuss concepts with others - - Listen to audio versions of materials - - **For Reading/Writing Learners:** - - Take detailed notes - - Rewrite key concepts in your own words - - Create word-based outlines - - **For Kinesthetic Learners:** - - Apply concepts to real problems - - Use physical objects to model ideas - - Take breaks to move around while studying - """) - -# Add a help section -with st.sidebar.expander("â How to Use", expanded=False): - st.markdown(""" - **1. Upload Documents** - - Upload your files (PDF, DOCX, TXT) - - Click "Process Documents" to extract content - - **2. 
Ask Questions** - - Type your question in the Q&A tab - - Select your preferred learning style - - Click "Submit Question" to get answers - - **3. Generate Summary** - - Go to the Summarize tab - - Select your preferred learning style - - Click "Generate Summary" - - **4. Analyze Content** - - Use the Analyze tab to extract key information - - Different analysis types are available - - Select your preferred learning style for results - """) -# Add a clear button to the sidebar -with st.sidebar.expander("đ Connect With Me", expanded=False): - st.markdown(""" -