# cryogenic22's picture
# Update app.py
# 92daf2c verified
from __future__ import annotations
import streamlit as st
import os
import json
import time
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime
from dataclasses import dataclass
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
from langchain_community.document_loaders import PyPDFLoader
import tempfile
from utils.database import (
create_connection,
create_tables,
create_chat_tables,
get_all_documents,
get_collections,
get_collection_documents,
get_embeddings_model,
verify_database_tables,
create_collection,
add_document_to_collection,
get_recent_documents,
save_chat_message,
create_new_chat,
get_chat_messages,
get_document_tags,
add_document_tags,
delete_collection)
from utils.ai_utils import generate_document_tags, initialize_chat_system
@dataclass
class SessionState:
    """Default values for session state variables.

    Instances serve only as a bag of defaults consumed by
    ``initialize_session_state``; the live state is ``st.session_state``.
    """
    show_collection_dialog: bool = False  # collection-management dialog visible?
    selected_collection: Optional[Dict] = None  # DB row of the active collection
    chat_ready: bool = False  # True once a vector store + QA chain exist
    messages: Optional[List] = None  # chat transcript; None becomes [] at init time
    current_chat_id: Optional[int] = None  # DB id of the active chat
    vector_store: Optional[Any] = None  # FAISS index over uploaded documents
    qa_system: Optional[Any] = None  # retrieval QA chain built on vector_store
    reinitialize_chat: bool = False  # request a chat-system rebuild on rerun
def initialize_session_state():
    """Seed ``st.session_state`` with application defaults (idempotent).

    The first call also creates the on-disk data directories and records
    their paths in session state; later calls return immediately because
    the ``initialized`` flag is already set.
    """
    if 'initialized' in st.session_state:
        return

    defaults = SessionState()

    # Prefer the persistent /data mount (e.g. a Spaces volume) when present.
    base_dir = Path('/data') if os.path.exists('/data') else Path('data')
    stores_dir = base_dir / 'vector_stores'
    for directory in (base_dir, stores_dir):
        directory.mkdir(parents=True, exist_ok=True)

    st.session_state.update({
        'show_collection_dialog': defaults.show_collection_dialog,
        'selected_collection': defaults.selected_collection,
        'chat_ready': defaults.chat_ready,
        'messages': defaults.messages if defaults.messages is not None else [],
        'current_chat_id': defaults.current_chat_id,
        'vector_store': defaults.vector_store,
        'qa_system': defaults.qa_system,
        'reinitialize_chat': defaults.reinitialize_chat,
        'initialized': True,
        'data_path': base_dir,
        'vector_store_path': stores_dir,
        'show_explorer': False,
    })
def generate_document_tags(content: str) -> List[str]:
    """Generate tags/keywords for a document using an LLM.

    NOTE(review): this local definition shadows ``generate_document_tags``
    imported from ``utils.ai_utils`` at the top of the file — confirm which
    implementation is intended and remove the other.

    Args:
        content: Full document text; only the first 2000 characters are sent
            to the model to bound token usage.

    Returns:
        A list of stripped, non-empty tag strings, or ``[]`` on any failure
        (the error is surfaced via ``st.error``).
    """
    try:
        llm = ChatOpenAI(temperature=0.2, model="gpt-3.5-turbo")
        prompt = """Analyze the following document content and generate relevant tags/keywords.
        Focus on key themes, topics, and important terminology.
        Return only the tags as a comma-separated list.
        Content: {content}"""
        response = llm.invoke(prompt.format(content=content[:2000]))  # Use first 2000 chars
        # BUG FIX: ChatOpenAI.invoke returns a message object, not a str, so
        # splitting the response directly raised AttributeError; split the
        # message's .content text instead.
        text = response.content if hasattr(response, "content") else str(response)
        return [tag.strip() for tag in text.split(',') if tag.strip()]
    except Exception as e:
        st.error(f"Error generating tags: {e}")
        return []
def process_document(file_path: str, collection_id: Optional[int] = None) -> Tuple[List, str, List[str]]:
    """Load a PDF, auto-tag it, and split it into embedding-ready chunks.

    Args:
        file_path: Path to the PDF on disk.
        collection_id: Optional collection the chunks belong to; stored in
            each chunk's metadata alongside the generated tags.

    Returns:
        Tuple of ``(chunks, full_text, tags)``; ``([], "", [])`` on failure
        (the error is reported via ``st.error``).
    """
    try:
        pages = PyPDFLoader(file_path).load()
        full_text = "\n".join(page.page_content for page in pages)
        tags = generate_document_tags(full_text)

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            separators=["\n\n", "\n", " ", ""],
        )
        pieces = splitter.split_documents(pages)

        # Stamp every chunk so collection membership and tags travel with
        # the vectors into the store.
        extra_meta = {'collection_id': collection_id, 'tags': tags}
        for piece in pieces:
            piece.metadata.update(extra_meta)

        return pieces, full_text, tags
    except Exception as e:
        st.error(f"Error processing document: {e}")
        return [], "", []
def handle_document_upload(uploaded_files: List, collection_id: Optional[int] = None) -> bool:
    """Process uploaded PDFs end-to-end: store, tag, index, and enable chat.

    For each file: persist it to the database, attach AI-generated tags,
    optionally link it to ``collection_id``, then build one FAISS index over
    all chunks and initialize the QA system in session state.

    NOTE(review): ``insert_document`` and ``initialize_qa_system`` are not
    imported anywhere in this file — presumably they belong in
    ``utils.database`` / ``utils.ai_utils``; confirm and add the imports,
    otherwise this function raises NameError at runtime.

    Args:
        uploaded_files: Streamlit ``UploadedFile`` objects (PDFs).
        collection_id: Optional collection to attach the documents to.

    Returns:
        True on success, False on any error (reported via ``st.error``).
    """
    try:
        if not uploaded_files:
            # Guard: also avoids the ZeroDivisionError in the progress math.
            return False

        progress_container = st.empty()
        status_container = st.empty()
        progress_bar = progress_container.progress(0)

        embeddings = get_embeddings_model()
        if not embeddings:
            status_container.error("Failed to initialize embeddings model")
            return False
        progress_bar.progress(10)

        all_chunks = []
        documents = []
        # 70% of the bar is shared evenly across files (10% setup, 20% indexing).
        progress_per_file = 70 / len(uploaded_files)
        current_progress = 10

        for uploaded_file in uploaded_files:
            status_container.info(f"Processing {uploaded_file.name}...")
            # Write the upload to a named temp file so PyPDFLoader can read it;
            # close it before loading (required on Windows).
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
                tmp_file.write(uploaded_file.getvalue())
                tmp_file.flush()
                tmp_path = tmp_file.name
            try:
                chunks, content, tags = process_document(tmp_path, collection_id)
            finally:
                # BUG FIX: the delete=False temp file was previously never
                # removed, leaking one file per upload.
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

            doc_id = insert_document(st.session_state.db_conn, uploaded_file.name, content)
            if not doc_id:
                status_container.error(f"Failed to store document: {uploaded_file.name}")
                continue
            if tags:
                add_document_tags(st.session_state.db_conn, doc_id, tags)
            if collection_id:
                add_document_to_collection(st.session_state.db_conn, doc_id, collection_id)

            all_chunks.extend(chunks)
            documents.append(content)
            current_progress += progress_per_file
            progress_bar.progress(int(current_progress))

        if not all_chunks:
            # FAISS.from_documents raises on an empty list; fail clearly instead.
            status_container.error("No document content could be processed")
            return False

        status_container.info("Creating document index...")
        vector_store = FAISS.from_documents(all_chunks, embeddings)
        st.session_state.vector_store = vector_store
        st.session_state.qa_system = initialize_qa_system(vector_store)
        st.session_state.chat_ready = True
        progress_bar.progress(100)
        status_container.success("Documents processed successfully!")

        # Let the success message linger briefly, then clear the widgets.
        time.sleep(2)
        progress_container.empty()
        status_container.empty()
        return True
    except Exception as e:
        st.error(f"Error uploading documents: {e}")
        return False
def display_header():
    """Display the application header with navigation.

    The nav buttons mutate session-state flags (``chat_ready``,
    ``show_explorer``, ``show_collection_dialog``, ...) and call
    ``st.rerun`` so ``main`` routes to the matching view on the next run.
    """
    # Add custom CSS so buttons fill their columns uniformly.
    st.markdown(
        """
        <style>
        .stButton > button {
            width: 100%;
            margin-bottom: 0;
        }
        .header-button {
            margin: 0 5px;
        }
        </style>
        """,
        unsafe_allow_html=True
    )
    # Create header layout
    header_container = st.container()
    with header_container:
        # Main header row: logo | title | four navigation buttons
        col1, col2, col3, col4, col5, col6 = st.columns([1.5, 2.5, 1, 1, 1, 1])
        # Logo
        with col1:
            if os.path.exists("img/logo.png"):
                st.image("img/logo.png", width=150)
            else:
                st.info("Logo missing: img/logo.png")
        # Title
        with col2:
            st.markdown("##### Synaptyx RFP Analyzer Agent")
        # Navigation Buttons
        with col3:
            # Home: clear chat/explorer state and fall through to the welcome screen.
            if st.button("๐Ÿ  Home", use_container_width=True, key="home_btn"):
                st.session_state.chat_ready = False
                st.session_state.messages = []
                st.session_state.current_chat_id = None
                st.session_state.show_explorer = False
                st.rerun()
        with col4:
            # Explorer: switch to the document-chunk explorer view.
            if st.button("๐Ÿ“š Explorer", use_container_width=True, key="explorer_btn"):
                st.session_state.show_explorer = True
                st.session_state.chat_ready = False
                st.rerun()
        with col5:
            # New Chat is only offered while a chat session is active.
            if st.session_state.chat_ready:
                if st.button("๐Ÿ’ญ New Chat", use_container_width=True, key="chat_btn"):
                    st.session_state.messages = []
                    st.session_state.current_chat_id = None
                    st.rerun()
        with col6:
            # Upload: open the collection-management / upload dialog.
            if st.button("๐Ÿ“ Upload", use_container_width=True, key="upload_btn"):
                st.session_state.show_collection_dialog = True
                st.rerun()
    # Add divider after header
    st.divider()
def display_collection_management():
    """Display collection management interface.

    Left column: a create-collection form plus an expander per existing
    collection (description, documents, tags). Inside each expander the
    right sub-column offers per-collection upload, chat start, and delete.
    """
    st.header("๐Ÿ“ Collection Management")
    col1, col2 = st.columns([2, 1])
    with col1:
        # Create new collection form
        with st.form("create_collection_form"):
            st.subheader("Create New Collection")
            name = st.text_input("Collection Name")
            description = st.text_area("Description")
            submit = st.form_submit_button("Create Collection", use_container_width=True)
            if submit and name:
                collection_id = create_collection(st.session_state.db_conn, name, description)
                if collection_id:
                    st.success(f"Collection '{name}' created successfully!")
                    st.session_state.current_collection_id = collection_id
                    st.rerun()
    # Display existing collections
    collections = get_collections(st.session_state.db_conn)
    if collections:
        st.markdown("### Existing Collections")
        for collection in collections:
            with st.expander(f"๐Ÿ“ {collection['name']} ({collection['doc_count']} documents)"):
                # NOTE(review): these col1/col2 shadow the outer page columns.
                col1, col2 = st.columns([3, 1])
                with col1:
                    st.write(f"**Description:** {collection.get('description', 'No description')}")
                    st.write(f"**Created:** {collection['created_at']}")
                    # Display documents in collection
                    docs = get_collection_documents(st.session_state.db_conn, collection['id'])
                    if docs:
                        st.write("**Documents:**")
                        for doc in docs:
                            st.write(f"- {doc['name']}")
                            tags = get_document_tags(st.session_state.db_conn, doc['id'])
                            if tags:
                                st.write(f"  Tags: {', '.join(tags)}")
                with col2:
                    # Add documents to collection
                    uploaded_files = st.file_uploader(
                        "Add Documents",
                        type=['pdf'],
                        accept_multiple_files=True,
                        key=f"collection_upload_{collection['id']}"
                    )
                    if uploaded_files:
                        if handle_document_upload(uploaded_files, collection_id=collection['id']):
                            st.success("Documents added successfully!")
                            st.rerun()
                    if st.button("Start Chat", key=f"chat_{collection['id']}", use_container_width=True):
                        st.session_state.selected_collection = collection
                        initialize_chat_system(collection['id'])
                        st.rerun()
                    if st.button("Delete Collection", key=f"delete_{collection['id']}", use_container_width=True):
                        # NOTE(review): st.warning returns an element object, which is
                        # always truthy — this is NOT a confirmation dialog, so the
                        # collection is deleted immediately. Replace with a two-step
                        # confirm (e.g. a checkbox or session-state flag).
                        if st.warning("Are you sure you want to delete this collection?"):
                            if delete_collection(st.session_state.db_conn, collection['id']):
                                st.success("Collection deleted successfully!")
                                st.rerun()
def display_chat_interface():
    """Display the main chat interface with persistent storage.

    Lazily creates a chat row in the DB on first use, replays the
    in-memory transcript, and for each new prompt queries the QA system
    and persists both sides of the exchange via ``save_chat_message``.
    """
    st.header("๐Ÿ’ฌ Ask your documents")
    # Create new chat if needed (first message of a fresh session).
    if not st.session_state.current_chat_id:
        chat_title = f"Chat {datetime.now().strftime('%Y-%m-%d %H:%M')}"
        collection_id = st.session_state.selected_collection['id'] if st.session_state.selected_collection else None
        st.session_state.current_chat_id = create_new_chat(st.session_state.db_conn, chat_title, collection_id)
    # Display chat messages (role inferred from the LangChain message type).
    for message in st.session_state.messages:
        with st.chat_message("user" if isinstance(message, HumanMessage) else "assistant"):
            st.markdown(message.content)
    # Chat input
    if prompt := st.chat_input("Ask a question about your documents..."):
        st.session_state.messages.append(HumanMessage(content=prompt))
        with st.spinner("Analyzing your documents..."):
            # The transcript doubles as chat history for the retrieval chain.
            response = st.session_state.qa_system.invoke({
                "input": prompt,
                "chat_history": st.session_state.messages
            })
            # Save messages to database
            save_chat_message(
                st.session_state.db_conn,
                st.session_state.current_chat_id,
                "human",
                prompt
            )
            save_chat_message(
                st.session_state.db_conn,
                st.session_state.current_chat_id,
                "assistant",
                response.content
            )
            st.session_state.messages.append(AIMessage(content=response.content))
            st.rerun()
def display_welcome_screen():
    """Display welcome screen with quick actions.

    Left column: upload documents (optionally into a collection) and create
    new collections. Right column: browse existing collections and recent
    documents, each with a shortcut to start a chat.
    """
    st.header("Quick Start")
    col1, col2 = st.columns([3, 2])
    with col1:
        # Upload new documents
        st.markdown("### Upload Documents")
        collection_id = None
        collections = get_collections(st.session_state.db_conn)
        if collections:
            # Options are (label, id) tuples; the first entry means "no collection".
            selected_collection = st.selectbox(
                "Select Collection (Optional)",
                options=[("None", None)] + [(c["name"], c["id"]) for c in collections],
                format_func=lambda x: x[0]
            )
            collection_id = selected_collection[1] if selected_collection[0] != "None" else None
        # Add new collection button
        if st.button("Create New Collection", use_container_width=True):
            st.session_state.show_collection_dialog = True
            st.rerun()
        uploaded_files = st.file_uploader(
            "Upload Documents",
            type=['pdf'],
            accept_multiple_files=True,
            help="Upload PDF documents to start analyzing"
        )
        if uploaded_files:
            with st.spinner("Processing documents..."):
                if handle_document_upload(uploaded_files, collection_id=collection_id):
                    initialize_chat_system(collection_id)
                    st.rerun()
    with col2:
        # Display existing collections
        st.header("Collections")
        if collections:
            for collection in collections:
                with st.expander(f"๐Ÿ“ {collection['name']} ({collection['doc_count']} documents)"):
                    st.write(collection.get('description', ''))
                    if st.button("Start Chat", key=f"chat_{collection['id']}", use_container_width=True):
                        st.session_state.selected_collection = collection
                        if initialize_chat_system(collection['id']):
                            st.rerun()
        # Show recent documents
        st.header("Recent Documents")
        recent_docs = get_recent_documents(st.session_state.db_conn, limit=5)
        for doc in recent_docs:
            with st.expander(f"๐Ÿ“„ {doc['name']}"):
                st.caption(f"Upload date: {doc['upload_date']}")
                if doc['collections']:
                    st.caption(f"Collections: {', '.join(doc['collections'])}")
                if st.button("Start Chat", key=f"doc_{doc['id']}", use_container_width=True):
                    if initialize_chat_system():
                        st.rerun()
def display_document_chunks():
    """Display document chunks with search and filtering capabilities.

    Chunks come from the session's current vector store via similarity
    search; an empty query falls back to a broad k=100 search.

    NOTE(review): ``selected_doc`` is displayed but never used to scope the
    search, and ``embeddings`` / ``sort_by`` are assigned but unused —
    confirm whether per-document filtering and sorting were intended.
    """
    st.subheader("Document Chunk Explorer")
    # Get all documents
    documents = get_all_documents(st.session_state.db_conn)
    if not documents:
        st.info("No documents available.")
        return
    # Document selection
    selected_doc = st.selectbox(
        "Select Document",
        options=documents,
        format_func=lambda x: x['name']
    )
    if not selected_doc:
        return
    try:
        # Load vector store for selected document
        embeddings = get_embeddings_model()
        chunks = []
        # Search functionality
        search_query = st.text_input("๐Ÿ” Search within chunks")
        if search_query and st.session_state.vector_store:
            chunks = st.session_state.vector_store.similarity_search(search_query, k=5)
        elif st.session_state.vector_store:
            # Broad fallback: empty query with a large k to list many chunks.
            chunks = st.session_state.vector_store.similarity_search("", k=100)
        # Display chunks with metadata
        st.markdown("### Document Chunks")
        # Filtering options
        col1, col2 = st.columns(2)
        with col1:
            chunk_size = st.slider("Preview Size", 100, 1000, 500)
        with col2:
            sort_by = st.selectbox("Sort By", ["Relevance", "Position"])
        # Display chunks in an organized way
        for i, chunk in enumerate(chunks):
            with st.expander(f"Chunk {i+1} | Source: {chunk.metadata.get('source', 'Unknown')}"):
                # Content preview, truncated to the slider-selected length.
                st.markdown("**Content:**")
                st.text(chunk.page_content[:chunk_size] + "..." if len(chunk.page_content) > chunk_size else chunk.page_content)
                # Metadata
                st.markdown("**Metadata:**")
                for key, value in chunk.metadata.items():
                    st.text(f"{key}: {value}")
                # Actions
                col1, col2 = st.columns(2)
                with col1:
                    # NOTE(review): no clipboard interaction actually happens here.
                    if st.button("Copy", key=f"copy_{i}"):
                        st.write("Content copied to clipboard!")
                with col2:
                    # Seed a new chat with this chunk's opening text.
                    if st.button("Start Chat", key=f"chat_{i}"):
                        initialize_chat_system()
                        st.session_state.messages.append(
                            HumanMessage(content=f"Tell me about: {chunk.page_content[:100]}...")
                        )
                        st.rerun()
    except Exception as e:
        st.error(f"Error loading document chunks: {e}")
def main():
    """Main application function with improved state management.

    Configures the page, seeds session state, opens (once per session) the
    SQLite connection and ensures its tables exist, renders the header, and
    then routes to exactly one view based on session-state flags — priority
    order: collection dialog > chat > explorer > welcome screen.
    """
    st.set_page_config(
        page_title="Synaptyx RFP Analyzer Agent",
        layout="wide",
        initial_sidebar_state="collapsed"
    )
    # Initialize session state with paths
    initialize_session_state()
    # Initialize database connection (cached in session state across reruns).
    if 'db_conn' not in st.session_state:
        db_path = st.session_state.data_path / 'analysis.db'
        st.session_state.db_conn = create_connection(str(db_path))
        create_tables(st.session_state.db_conn)
        create_chat_tables(st.session_state.db_conn)
        verify_database_tables(st.session_state.db_conn)
    # Display header
    display_header()
    # Show different views based on application state
    if st.session_state.show_collection_dialog:
        display_collection_management()
    elif st.session_state.chat_ready:
        display_chat_interface()
    elif st.session_state.show_explorer:
        display_document_chunks()
    else:
        display_welcome_screen()

if __name__ == "__main__":
    main()