Spaces:
Paused
Paused
| import streamlit as st | |
| import os | |
| import tempfile | |
| import pandas as pd | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from document_classifier import DocumentClassifier | |
| import time | |
| from typing import List, Dict | |
| import json | |
| import requests | |
# Page configuration
# Must be the first Streamlit call in the script.
st.set_page_config(
    page_title="Document Classifier",
    page_icon="π",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS for better styling
# Class hooks (.main-header, .sub-header, ...) are referenced by the raw-HTML
# st.markdown calls further down in main().
st.markdown("""
<style>
    .main-header {
        font-size: 3rem;
        color: #1f77b4;
        text-align: center;
        margin-bottom: 2rem;
    }
    .sub-header {
        font-size: 1.5rem;
        color: #2c3e50;
        margin-top: 2rem;
        margin-bottom: 1rem;
    }
    .metric-card {
        background-color: #f8f9fa;
        padding: 1rem;
        border-radius: 0.5rem;
        border-left: 4px solid #1f77b4;
    }
    .success-message {
        background-color: #d4edda;
        color: #155724;
        padding: 1rem;
        border-radius: 0.5rem;
        border: 1px solid #c3e6cb;
    }
    .error-message {
        background-color: #f8d7da;
        color: #721c24;
        padding: 1rem;
        border-radius: 0.5rem;
        border: 1px solid #f5c6cb;
    }
</style>
""", unsafe_allow_html=True)

# Initialize session state (survives Streamlit reruns).
# classifier: lazily-constructed DocumentClassifier, built by initialize_classifier().
if 'classifier' not in st.session_state:
    st.session_state.classifier = None
# classification_results: result dicts from the most recent classification run.
if 'classification_results' not in st.session_state:
    st.session_state.classification_results = []
# uploaded_files: appears unused elsewhere in this file — presumably reserved; verify before removing.
if 'uploaded_files' not in st.session_state:
    st.session_state.uploaded_files = []
def initialize_classifier():
    """Build the DocumentClassifier once and cache it in session state.

    Returns:
        bool: True when a classifier is available (already cached or freshly
        constructed), False when construction raised.
    """
    # Already built on a previous rerun — nothing to do.
    if st.session_state.classifier is not None:
        return True
    with st.spinner("Loading Hugging Face models..."):
        try:
            st.session_state.classifier = DocumentClassifier()
        except Exception as e:
            st.error(f"β Failed to initialize classifier: {str(e)}")
            return False
        st.success("β Document classifier initialized successfully!")
        return True
def save_uploaded_file(uploaded_file) -> str:
    """Persist an uploaded file to a temporary file on disk.

    Args:
        uploaded_file: Streamlit UploadedFile-like object exposing ``.name``
            and ``.getbuffer()``.

    Returns:
        Path of the temporary copy (caller is responsible for unlinking it),
        or None when writing failed.
    """
    try:
        # os.path.splitext keeps the real extension (dot included) and yields
        # "" for extension-less names; the previous name.split('.')[-1] would
        # turn "README" into a bogus ".README" suffix.
        suffix = os.path.splitext(uploaded_file.name)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(uploaded_file.getbuffer())
            return tmp_file.name
    except Exception as e:
        st.error(f"Error saving file: {str(e)}")
        return None
def classify_single_file(file_path: str) -> Dict:
    """Classify one document with the session-cached classifier.

    Returns the classifier's result dict, or an ``{"error": ..., "success":
    False}`` dict when the classifier is missing or raises.
    """
    classifier = st.session_state.classifier
    if not classifier:
        return {"error": "Classifier not initialized", "success": False}
    try:
        return classifier.classify_document(file_path)
    except Exception as e:
        return {"error": str(e), "success": False}
def classify_multiple_files(file_paths: List[str]) -> List[Dict]:
    """Classify a batch of documents with the session-cached classifier.

    Returns one result dict per file; on failure, a single-element list
    holding an ``{"error": ..., "success": False}`` dict.
    """
    classifier = st.session_state.classifier
    if not classifier:
        return [{"error": "Classifier not initialized", "success": False}]
    try:
        return classifier.classify_multiple_documents(file_paths)
    except Exception as e:
        return [{"error": str(e), "success": False}]
def classify_with_api(file_path: str) -> Dict:
    """Call FastAPI classify endpoint with image file.

    Posts the file as multipart form data and backfills any keys missing
    from the API payload so display_classification_result() can render it.

    Args:
        file_path: Local path of the image to send.

    Returns:
        The normalized response payload on HTTP 200, otherwise an
        ``{"error": ..., "success": False}`` dict.
    """
    api_url = "http://localhost:8000/classify"  # Adjust if API runs elsewhere
    try:
        with open(file_path, "rb") as file_data:
            files = {"file": (os.path.basename(file_path), file_data)}
            # Explicit timeout: requests has no default, so a hung API would
            # otherwise block the Streamlit script forever.
            response = requests.post(api_url, files=files, timeout=60)
        if response.status_code == 200:
            data = response.json()
            # Some keys might not be present in API/minimal - match display_classification_result
            extension = os.path.splitext(file_path)[1].replace('.', '')
            data.setdefault('file_path', file_path)
            data.setdefault('file_name', os.path.basename(file_path))
            data.setdefault('file_extension', extension)
            data.setdefault('content_length', 0)
            data.setdefault('text_preview', '')
            data.setdefault('file_type', extension)
            data.setdefault('all_scores', {})
            return data
        return {"error": response.text, "success": False}
    except Exception as e:
        return {"error": str(e), "success": False}
def display_classification_result(result: Dict):
    """Render one classification result: metrics, details, preview, scores.

    Expects the keys produced by classify_single_file / classify_with_api:
    'success', 'file_type', 'classification', 'confidence', 'file_name',
    'file_extension', 'content_length', 'file_path', 'text_preview',
    'all_scores' (label -> score mapping).
    """
    # Failed results get only an error banner.
    if not result.get('success', False):
        st.error(f"β Classification failed: {result.get('error', 'Unknown error')}")
        return
    # Headline metrics in a three-column row.
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Document Type", result['file_type'])
    with col2:
        st.metric("Classification", result['classification'].title())
    with col3:
        st.metric("Confidence", f"{result['confidence']:.2%}")
    # Display detailed information
    st.subheader("π Document Details")
    col1, col2 = st.columns(2)
    with col1:
        st.write(f"**File Name:** {result['file_name']}")
        st.write(f"**File Extension:** {result['file_extension']}")
        st.write(f"**Content Length:** {result['content_length']} characters")
    with col2:
        st.write(f"**File Path:** {result['file_path']}")
        st.write(f"**Classification Confidence:** {result['confidence']:.2%}")
    # Display text preview (skipped when empty, e.g. API results).
    if result['text_preview']:
        st.subheader("π Text Preview")
        st.text_area("Content Preview", result['text_preview'], height=150, disabled=True)
    # Display all classification scores, highest first.
    st.subheader("π Classification Scores")
    scores_df = pd.DataFrame(list(result['all_scores'].items()), columns=['Document Type', 'Score'])
    scores_df['Score'] = scores_df['Score'].round(4)
    scores_df = scores_df.sort_values('Score', ascending=False)
    # Create a bar chart of per-label confidence.
    fig = px.bar(scores_df, x='Document Type', y='Score',
                 title="Classification Confidence Scores",
                 color='Score',
                 color_continuous_scale='Blues')
    fig.update_layout(xaxis_tickangle=-45)
    st.plotly_chart(fig, use_container_width=True)
    # Display scores table
    st.dataframe(scores_df, use_container_width=True)
def display_batch_results(results: List[Dict]):
    """Render summary metrics, a distribution chart, and per-file tables
    for a batch of classification result dicts.
    """
    if not results:
        st.warning("No results to display.")
        return
    # Split results by outcome; 'success' defaults to False when absent.
    successful_results = [r for r in results if r.get('success', False)]
    failed_results = [r for r in results if not r.get('success', False)]
    # Summary statistics row.
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Total Files", len(results))
    with col2:
        st.metric("Successful", len(successful_results))
    with col3:
        st.metric("Failed", len(failed_results))
    with col4:
        # Average confidence only makes sense when something succeeded;
        # the column is left empty otherwise.
        if successful_results:
            avg_confidence = sum(r['confidence'] for r in successful_results) / len(successful_results)
            st.metric("Avg Confidence", f"{avg_confidence:.2%}")
    # Classification distribution pie chart.
    if successful_results:
        st.subheader("π Classification Distribution")
        classifications = [r['classification'] for r in successful_results]
        classification_counts = pd.Series(classifications).value_counts()
        fig = px.pie(values=classification_counts.values,
                     names=classification_counts.index,
                     title="Document Type Distribution")
        st.plotly_chart(fig, use_container_width=True)
    # Detailed results table for successful classifications.
    st.subheader("π Detailed Results")
    if successful_results:
        results_data = []
        for result in successful_results:
            results_data.append({
                'File Name': result['file_name'],
                'File Type': result['file_type'],
                'Classification': result['classification'].title(),
                'Confidence': f"{result['confidence']:.2%}",
                'Content Length': result['content_length']
            })
        results_df = pd.DataFrame(results_data)
        st.dataframe(results_df, use_container_width=True)
    # Show failed results, one error banner per file.
    if failed_results:
        st.subheader("β Failed Classifications")
        for result in failed_results:
            st.error(f"**{result.get('file_name', 'Unknown')}**: {result.get('error', 'Unknown error')}")
def main():
    """Main Streamlit application: header, sidebar, and the three tabs
    (single file, batch upload, results/export)."""
    # Header
    st.markdown('<h1 class="main-header">π Document Classifier</h1>', unsafe_allow_html=True)
    st.markdown("""
    <div style="text-align: center; margin-bottom: 2rem;">
        <p style="font-size: 1.2rem; color: #666;">
            Classify documents using Hugging Face models and content analysis
        </p>
    </div>
    """, unsafe_allow_html=True)
    # Sidebar
    st.sidebar.title("βοΈ Settings")
    # Initialize classifier on demand (also triggered lazily by the
    # classify buttons below).
    if st.sidebar.button("π Initialize Classifier", type="primary"):
        initialize_classifier()
    # Model information
    st.sidebar.subheader("π€ Model Information")
    st.sidebar.info("""
    **Models Used:**
    - Cardiff NLP Twitter RoBERTa Base Emotion
    - DistilBERT Base Uncased (fallback)

    **Supported Formats:**
    - PDF, DOCX, DOC
    - TXT, CSV
    - XLSX, XLS
    - Images (JPG, PNG, etc.)
    """)
    # Main content
    tab1, tab2, tab3 = st.tabs(["π Single File", "π Batch Upload", "π Results"])
    with tab1:
        st.subheader("\U0001F4C1 Classify Single Document")
        # Guard against a missing sample_scans directory: os.listdir would
        # otherwise raise FileNotFoundError and kill the whole page.
        sample_scan_folder = os.path.join(os.path.dirname(__file__), "sample_scans")
        if os.path.isdir(sample_scan_folder):
            sample_files = [f for f in os.listdir(sample_scan_folder) if f.lower().endswith((".png", ".jpg", ".jpeg"))]
            sample_files.sort()
        else:
            sample_files = []
        selected_sample = st.selectbox("Or select from example scans:", ["--- Select a sample ---"] + sample_files)
        if selected_sample != "--- Select a sample ---":
            sample_file_path = os.path.join(sample_scan_folder, selected_sample)
            st.image(sample_file_path, caption=f"Sample: {selected_sample}", use_column_width=True)
            if st.button("π Classify Sample Scan", key="classify_sample_scan"):
                with st.spinner("Calling API to classify sample scan..."):
                    result = classify_with_api(sample_file_path)
                display_classification_result(result)
        uploaded_file = st.file_uploader(
            "Choose a document file",
            type=['pdf', 'docx', 'doc', 'txt', 'csv', 'xlsx', 'xls', 'jpg', 'jpeg', 'png'],
            help="Upload a document to classify its type and content"
        )
        if uploaded_file is not None:
            if st.button("π Classify Document", type="primary"):
                if not initialize_classifier():
                    st.stop()
                # Save uploaded file to a temp path the classifier can read.
                file_path = save_uploaded_file(uploaded_file)
                if file_path:
                    with st.spinner("Classifying document..."):
                        result = classify_single_file(file_path)
                        st.session_state.classification_results = [result]
                    # Clean up temporary file; narrowed from a bare except so
                    # only filesystem errors are ignored.
                    try:
                        os.unlink(file_path)
                    except OSError:
                        pass
                    # Display result
                    display_classification_result(result)
    with tab2:
        st.subheader("π Batch Document Classification")
        uploaded_files = st.file_uploader(
            "Choose multiple document files",
            type=['pdf', 'docx', 'doc', 'txt', 'csv', 'xlsx', 'xls', 'jpg', 'jpeg', 'png'],
            accept_multiple_files=True,
            help="Upload multiple documents to classify them in batch"
        )
        if uploaded_files:
            st.write(f"π {len(uploaded_files)} files selected")
            if st.button("π Classify All Documents", type="primary"):
                if not initialize_classifier():
                    st.stop()
                # Save uploaded files to temp paths first.
                file_paths = []
                for uploaded_file in uploaded_files:
                    file_path = save_uploaded_file(uploaded_file)
                    if file_path:
                        file_paths.append(file_path)
                if file_paths:
                    progress_bar = st.progress(0)
                    status_text = st.empty()
                    results = []
                    # Classify one-by-one so the progress bar can advance.
                    for i, file_path in enumerate(file_paths):
                        status_text.text(f"Processing file {i+1}/{len(file_paths)}: {os.path.basename(file_path)}")
                        result = classify_single_file(file_path)
                        results.append(result)
                        progress_bar.progress((i + 1) / len(file_paths))
                        # Clean up temporary file (best effort).
                        try:
                            os.unlink(file_path)
                        except OSError:
                            pass
                    st.session_state.classification_results = results
                    status_text.text("β Classification complete!")
                    # Display batch results
                    display_batch_results(results)
    with tab3:
        st.subheader("π Classification Results")
        if st.session_state.classification_results:
            if len(st.session_state.classification_results) == 1:
                display_classification_result(st.session_state.classification_results[0])
            else:
                display_batch_results(st.session_state.classification_results)
        else:
            st.info("π Upload and classify documents to see results here.")
        # Export results
        if st.session_state.classification_results:
            st.subheader("πΎ Export Results")
            col1, col2 = st.columns(2)
            with col1:
                # NOTE(review): a download_button nested under st.button only
                # survives one rerun - confirm whether it should render
                # unconditionally instead.
                if st.button("π Export as CSV"):
                    successful_results = [r for r in st.session_state.classification_results if r.get('success', False)]
                    if successful_results:
                        export_data = []
                        for result in successful_results:
                            export_data.append({
                                'File Name': result['file_name'],
                                'File Type': result['file_type'],
                                'Classification': result['classification'],
                                'Confidence': result['confidence'],
                                'Content Length': result['content_length']
                            })
                        df = pd.DataFrame(export_data)
                        csv = df.to_csv(index=False)
                        st.download_button(
                            label="Download CSV",
                            data=csv,
                            file_name="classification_results.csv",
                            mime="text/csv"
                        )
            with col2:
                if st.button("π Export as JSON"):
                    json_data = json.dumps(st.session_state.classification_results, indent=2)
                    st.download_button(
                        label="Download JSON",
                        data=json_data,
                        file_name="classification_results.json",
                        mime="application/json"
                    )
| if __name__ == "__main__": | |
| main() | |