"""Streamlit front-end for the document classifier.

Provides a single-file tab (with API-backed sample scans), a batch tab with a
progress bar, a results tab with Plotly visualisations, and CSV/JSON export.
"""

import streamlit as st
import os
import tempfile
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from document_classifier import DocumentClassifier
import time
from typing import List, Dict
import json
import requests

# Page configuration
st.set_page_config(
    page_title="Document Classifier",
    page_icon="📄",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Custom CSS for better styling (placeholder — no rules defined yet)
st.markdown("""
""", unsafe_allow_html=True)

# Initialize session state so reruns keep the classifier and last results.
if 'classifier' not in st.session_state:
    st.session_state.classifier = None
if 'classification_results' not in st.session_state:
    st.session_state.classification_results = []
if 'uploaded_files' not in st.session_state:
    st.session_state.uploaded_files = []


def initialize_classifier() -> bool:
    """Lazily create the DocumentClassifier and cache it in session state.

    Returns:
        True when a classifier is available (already loaded or just loaded),
        False when model loading failed (an error is shown in the UI).
    """
    if st.session_state.classifier is None:
        with st.spinner("Loading Hugging Face models..."):
            try:
                st.session_state.classifier = DocumentClassifier()
                st.success("✅ Document classifier initialized successfully!")
                return True
            except Exception as e:
                st.error(f"❌ Failed to initialize classifier: {str(e)}")
                return False
    return True


def save_uploaded_file(uploaded_file) -> str:
    """Save an uploaded file to a temporary file, preserving its extension.

    Returns the temporary file path, or None on failure. The caller is
    responsible for deleting the file when done.
    """
    try:
        suffix = f".{uploaded_file.name.split('.')[-1]}"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(uploaded_file.getbuffer())
            return tmp_file.name
    except Exception as e:
        st.error(f"Error saving file: {str(e)}")
        return None


def classify_single_file(file_path: str) -> Dict:
    """Classify a single file; returns a result dict, never raises."""
    if not st.session_state.classifier:
        return {"error": "Classifier not initialized", "success": False}
    try:
        return st.session_state.classifier.classify_document(file_path)
    except Exception as e:
        return {"error": str(e), "success": False}


def classify_multiple_files(file_paths: List[str]) -> List[Dict]:
    """Classify multiple files; returns one result dict per file."""
    if not st.session_state.classifier:
        return [{"error": "Classifier not initialized", "success": False}]
    try:
        return st.session_state.classifier.classify_multiple_documents(file_paths)
    except Exception as e:
        return [{"error": str(e), "success": False}]


def classify_with_api(file_path: str) -> Dict:
    """Call the FastAPI classify endpoint with an image file.

    Returns the API's JSON payload (padded with the keys the display helpers
    expect) or an {"error": ..., "success": False} dict on any failure.
    """
    api_url = "http://localhost:8000/classify"  # Adjust if API runs elsewhere
    try:
        with open(file_path, "rb") as file_data:
            files = {"file": (os.path.basename(file_path), file_data)}
            # A timeout keeps a hung API server from blocking the UI forever.
            response = requests.post(api_url, files=files, timeout=60)
        if response.status_code == 200:
            data = response.json()
            # Some keys might not be present in the API response — fill in
            # defaults so display_classification_result can render it.
            ext = os.path.splitext(file_path)[1].replace('.', '')
            data.setdefault('file_path', file_path)
            data.setdefault('file_name', os.path.basename(file_path))
            data.setdefault('file_extension', ext)
            data.setdefault('content_length', 0)
            data.setdefault('text_preview', '')
            data.setdefault('file_type', ext)
            data.setdefault('all_scores', {})
            return data
        return {"error": response.text, "success": False}
    except Exception as e:
        return {"error": str(e), "success": False}


def display_classification_result(result: Dict):
    """Render one classification result: metrics, details, preview, scores."""
    if not result.get('success', False):
        st.error(f"❌ Classification failed: {result.get('error', 'Unknown error')}")
        return

    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Document Type", result['file_type'])
    with col2:
        st.metric("Classification", result['classification'].title())
    with col3:
        st.metric("Confidence", f"{result['confidence']:.2%}")

    # Display detailed information
    st.subheader("📋 Document Details")
    col1, col2 = st.columns(2)
    with col1:
        st.write(f"**File Name:** {result['file_name']}")
        st.write(f"**File Extension:** {result['file_extension']}")
        st.write(f"**Content Length:** {result['content_length']} characters")
    with col2:
        st.write(f"**File Path:** {result['file_path']}")
        st.write(f"**Classification Confidence:** {result['confidence']:.2%}")

    # Display text preview
    if result['text_preview']:
        st.subheader("📖 Text Preview")
        st.text_area("Content Preview", result['text_preview'], height=150, disabled=True)

    # Display all classification scores
    st.subheader("📊 Classification Scores")
    scores_df = pd.DataFrame(list(result['all_scores'].items()),
                             columns=['Document Type', 'Score'])
    scores_df['Score'] = scores_df['Score'].round(4)
    scores_df = scores_df.sort_values('Score', ascending=False)

    # Create a bar chart
    fig = px.bar(scores_df, x='Document Type', y='Score',
                 title="Classification Confidence Scores",
                 color='Score', color_continuous_scale='Blues')
    fig.update_layout(xaxis_tickangle=-45)
    st.plotly_chart(fig, use_container_width=True)

    # Display scores table
    st.dataframe(scores_df, use_container_width=True)


def display_batch_results(results: List[Dict]):
    """Render batch results: summary metrics, distribution pie, tables."""
    if not results:
        st.warning("No results to display.")
        return

    # Summary statistics
    successful_results = [r for r in results if r.get('success', False)]
    failed_results = [r for r in results if not r.get('success', False)]

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Total Files", len(results))
    with col2:
        st.metric("Successful", len(successful_results))
    with col3:
        st.metric("Failed", len(failed_results))
    with col4:
        if successful_results:
            avg_confidence = sum(r['confidence'] for r in successful_results) / len(successful_results)
            st.metric("Avg Confidence", f"{avg_confidence:.2%}")

    # Classification distribution
    if successful_results:
        st.subheader("📊 Classification Distribution")
        classifications = [r['classification'] for r in successful_results]
        classification_counts = pd.Series(classifications).value_counts()
        fig = px.pie(values=classification_counts.values,
                     names=classification_counts.index,
                     title="Document Type Distribution")
        st.plotly_chart(fig, use_container_width=True)

    # Detailed results table
    st.subheader("📋 Detailed Results")
    if successful_results:
        results_data = [
            {
                'File Name': result['file_name'],
                'File Type': result['file_type'],
                'Classification': result['classification'].title(),
                'Confidence': f"{result['confidence']:.2%}",
                'Content Length': result['content_length'],
            }
            for result in successful_results
        ]
        results_df = pd.DataFrame(results_data)
        st.dataframe(results_df, use_container_width=True)

    # Show failed results
    if failed_results:
        st.subheader("❌ Failed Classifications")
        for result in failed_results:
            st.error(f"**{result.get('file_name', 'Unknown')}**: "
                     f"{result.get('error', 'Unknown error')}")


def _render_sidebar():
    """Sidebar: classifier init button and static model information."""
    st.sidebar.title("⚙️ Settings")
    if st.sidebar.button("🔄 Initialize Classifier", type="primary"):
        initialize_classifier()
    st.sidebar.subheader("🤖 Model Information")
    st.sidebar.info("""
    **Models Used:**
    - Cardiff NLP Twitter RoBERTa Base Emotion
    - DistilBERT Base Uncased (fallback)

    **Supported Formats:**
    - PDF, DOCX, DOC
    - TXT, CSV
    - XLSX, XLS
    - Images (JPG, PNG, etc.)
    """)


def _render_single_file_tab():
    """Single-file tab: sample-scan picker (API path) and direct upload."""
    st.subheader("📁 Classify Single Document")

    sample_scan_folder = os.path.join(os.path.dirname(__file__), "sample_scans")
    # Guard: the sample folder may be absent in a fresh checkout.
    if os.path.isdir(sample_scan_folder):
        sample_files = [f for f in os.listdir(sample_scan_folder)
                        if f.lower().endswith((".png", ".jpg", ".jpeg"))]
        sample_files.sort()
    else:
        sample_files = []

    selected_sample = st.selectbox("Or select from example scans:",
                                   ["--- Select a sample ---"] + sample_files)
    if selected_sample != "--- Select a sample ---":
        sample_file_path = os.path.join(sample_scan_folder, selected_sample)
        st.image(sample_file_path, caption=f"Sample: {selected_sample}",
                 use_column_width=True)
        if st.button("🔍 Classify Sample Scan", key="classify_sample_scan"):
            with st.spinner("Calling API to classify sample scan..."):
                result = classify_with_api(sample_file_path)
            display_classification_result(result)

    uploaded_file = st.file_uploader(
        "Choose a document file",
        type=['pdf', 'docx', 'doc', 'txt', 'csv', 'xlsx', 'xls', 'jpg', 'jpeg', 'png'],
        help="Upload a document to classify its type and content",
    )
    if uploaded_file is not None:
        if st.button("🔍 Classify Document", type="primary"):
            if not initialize_classifier():
                st.stop()
            # Save uploaded file to disk so the classifier can read it.
            file_path = save_uploaded_file(uploaded_file)
            if file_path:
                with st.spinner("Classifying document..."):
                    result = classify_single_file(file_path)
                    st.session_state.classification_results = [result]
                # Clean up temporary file (best effort).
                try:
                    os.unlink(file_path)
                except OSError:
                    pass
                display_classification_result(result)


def _render_batch_tab():
    """Batch tab: multi-upload, per-file progress, batch summary."""
    st.subheader("📂 Batch Document Classification")
    uploaded_files = st.file_uploader(
        "Choose multiple document files",
        type=['pdf', 'docx', 'doc', 'txt', 'csv', 'xlsx', 'xls', 'jpg', 'jpeg', 'png'],
        accept_multiple_files=True,
        help="Upload multiple documents to classify them in batch",
    )
    if not uploaded_files:
        return

    st.write(f"📁 {len(uploaded_files)} files selected")
    if st.button("🔍 Classify All Documents", type="primary"):
        if not initialize_classifier():
            st.stop()

        # Save uploaded files; skip any that fail to save.
        file_paths = []
        for uploaded_file in uploaded_files:
            file_path = save_uploaded_file(uploaded_file)
            if file_path:
                file_paths.append(file_path)

        if file_paths:
            progress_bar = st.progress(0)
            status_text = st.empty()
            results = []
            for i, file_path in enumerate(file_paths):
                status_text.text(f"Processing file {i+1}/{len(file_paths)}: "
                                 f"{os.path.basename(file_path)}")
                result = classify_single_file(file_path)
                results.append(result)
                progress_bar.progress((i + 1) / len(file_paths))
                # Clean up temporary file (best effort).
                try:
                    os.unlink(file_path)
                except OSError:
                    pass
            st.session_state.classification_results = results
            status_text.text("✅ Classification complete!")
            display_batch_results(results)


def _render_results_tab():
    """Results tab: show last run's results and offer CSV/JSON export."""
    st.subheader("📊 Classification Results")
    if st.session_state.classification_results:
        if len(st.session_state.classification_results) == 1:
            display_classification_result(st.session_state.classification_results[0])
        else:
            display_batch_results(st.session_state.classification_results)
    else:
        st.info("👆 Upload and classify documents to see results here.")

    if st.session_state.classification_results:
        st.subheader("💾 Export Results")
        col1, col2 = st.columns(2)
        with col1:
            if st.button("📄 Export as CSV"):
                successful_results = [r for r in st.session_state.classification_results
                                      if r.get('success', False)]
                if successful_results:
                    export_data = [
                        {
                            'File Name': result['file_name'],
                            'File Type': result['file_type'],
                            'Classification': result['classification'],
                            'Confidence': result['confidence'],
                            'Content Length': result['content_length'],
                        }
                        for result in successful_results
                    ]
                    df = pd.DataFrame(export_data)
                    csv = df.to_csv(index=False)
                    st.download_button(
                        label="Download CSV",
                        data=csv,
                        file_name="classification_results.csv",
                        mime="text/csv",
                    )
        with col2:
            if st.button("📋 Export as JSON"):
                json_data = json.dumps(st.session_state.classification_results, indent=2)
                st.download_button(
                    label="Download JSON",
                    data=json_data,
                    file_name="classification_results.json",
                    mime="application/json",
                )


def main():
    """Main Streamlit application: header, sidebar, and the three tabs."""
    # Header (HTML markup was lost in extraction; plain markdown preserved)
    st.markdown('📄 Document Classifier', unsafe_allow_html=True)
    st.markdown("""
Classify documents using Hugging Face models and content analysis
""", unsafe_allow_html=True)

    _render_sidebar()

    tab1, tab2, tab3 = st.tabs(["📁 Single File", "📂 Batch Upload", "📊 Results"])
    with tab1:
        _render_single_file_tab()
    with tab2:
        _render_batch_tab()
    with tab3:
        _render_results_tab()


if __name__ == "__main__":
    main()