# AIDocScanner — src/streamlit_app.py
# Streamlit front-end for the document classifier.
import json
import os
import tempfile
import time
from typing import Dict, List, Optional

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import requests
import streamlit as st

from document_classifier import DocumentClassifier
# Page configuration — must run before any other Streamlit call.
st.set_page_config(
    page_title="Document Classifier",
    page_icon="πŸ“„",  # browser-tab icon
    layout="wide",  # use the full browser width for the results tables/charts
    initial_sidebar_state="expanded"
)
# Custom CSS for better styling
st.markdown("""
<style>
.main-header {
font-size: 3rem;
color: #1f77b4;
text-align: center;
margin-bottom: 2rem;
}
.sub-header {
font-size: 1.5rem;
color: #2c3e50;
margin-top: 2rem;
margin-bottom: 1rem;
}
.metric-card {
background-color: #f8f9fa;
padding: 1rem;
border-radius: 0.5rem;
border-left: 4px solid #1f77b4;
}
.success-message {
background-color: #d4edda;
color: #155724;
padding: 1rem;
border-radius: 0.5rem;
border: 1px solid #c3e6cb;
}
.error-message {
background-color: #f8d7da;
color: #721c24;
padding: 1rem;
border-radius: 0.5rem;
border: 1px solid #f5c6cb;
}
</style>
""", unsafe_allow_html=True)
# Initialize session state: seed each key only on the first script run so
# reruns don't clobber the cached classifier or previous results.
_SESSION_DEFAULTS = {
    "classifier": None,
    "classification_results": [],
    "uploaded_files": [],
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
def initialize_classifier():
    """Lazily build the DocumentClassifier and cache it in session state.

    Returns:
        True when a classifier is available (already cached, or just
        constructed successfully); False when construction raised.
    """
    # Already initialized on a previous run — nothing to do.
    if st.session_state.classifier is not None:
        return True
    with st.spinner("Loading Hugging Face models..."):
        try:
            st.session_state.classifier = DocumentClassifier()
            st.success("βœ… Document classifier initialized successfully!")
            return True
        except Exception as exc:
            st.error(f"❌ Failed to initialize classifier: {str(exc)}")
            return False
def save_uploaded_file(uploaded_file) -> Optional[str]:
    """Persist a Streamlit uploaded file to a temporary file on disk.

    The temp file keeps the original extension so downstream type detection
    still works.

    Args:
        uploaded_file: Streamlit UploadedFile-like object exposing `.name`
            and `.getbuffer()`.

    Returns:
        Path to the temp file, or None when writing fails (the error is
        shown in the UI via st.error).
    """
    try:
        # os.path.splitext keeps the leading dot and yields "" for
        # extension-less names; the previous `name.split('.')[-1]` produced
        # a bogus suffix (e.g. ".README") for files without a dot.
        suffix = os.path.splitext(uploaded_file.name)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(uploaded_file.getbuffer())
            return tmp_file.name
    except Exception as e:
        st.error(f"Error saving file: {str(e)}")
        return None
def classify_single_file(file_path: str) -> Dict:
    """Run the session-cached classifier on one file.

    Returns the classifier's result dict, or an
    {"error": ..., "success": False} dict when the classifier is missing
    or the classification raises.
    """
    classifier = st.session_state.classifier
    if not classifier:
        return {"error": "Classifier not initialized", "success": False}
    try:
        return classifier.classify_document(file_path)
    except Exception as exc:
        return {"error": str(exc), "success": False}
def classify_multiple_files(file_paths: List[str]) -> List[Dict]:
    """Classify a batch of files via the session-cached classifier.

    Returns the classifier's list of result dicts, or a single-element
    error list when the classifier is missing or the batch call raises.
    """
    classifier = st.session_state.classifier
    if not classifier:
        return [{"error": "Classifier not initialized", "success": False}]
    try:
        return classifier.classify_multiple_documents(file_paths)
    except Exception as exc:
        return [{"error": str(exc), "success": False}]
def classify_with_api(file_path: str) -> Dict:
    """Call the FastAPI /classify endpoint with an image file.

    Posts the file as multipart form data and backfills any keys the API
    may omit so display_classification_result() can rely on them.

    Args:
        file_path: Path to the file to classify.

    Returns:
        The (normalized) JSON response dict on HTTP 200, otherwise
        {"error": ..., "success": False} for any HTTP, connection, or
        file-system failure.
    """
    api_url = "http://localhost:8000/classify"  # Adjust if API runs elsewhere
    try:
        with open(file_path, "rb") as file_data:
            files = {"file": (os.path.basename(file_path), file_data)}
            # A request without a timeout can hang the Streamlit script
            # forever if the API is down; fail after 30s instead.
            response = requests.post(api_url, files=files, timeout=30)
        if response.status_code != 200:
            return {"error": response.text, "success": False}
        data = response.json()
        # Backfill keys that might be absent from a minimal API response.
        extension = os.path.splitext(file_path)[1].lstrip('.')
        data.setdefault('file_path', file_path)
        data.setdefault('file_name', os.path.basename(file_path))
        data.setdefault('file_extension', extension)
        data.setdefault('content_length', 0)
        data.setdefault('text_preview', '')
        data.setdefault('file_type', extension)
        data.setdefault('all_scores', {})
        return data
    except Exception as e:
        return {"error": str(e), "success": False}
def display_classification_result(result: Dict):
    """Render one classification result: metrics, details, preview, scores.

    Expects the result dict produced by the classifier or classify_with_api,
    with keys: success, file_type, classification, confidence, file_name,
    file_extension, content_length, file_path, text_preview, all_scores.
    """
    # Failed results only get an error banner — none of the other keys are
    # guaranteed to exist in that case.
    if not result.get('success', False):
        st.error(f"❌ Classification failed: {result.get('error', 'Unknown error')}")
        return
    # Headline metrics in three columns.
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Document Type", result['file_type'])
    with col2:
        st.metric("Classification", result['classification'].title())
    with col3:
        st.metric("Confidence", f"{result['confidence']:.2%}")
    # Display detailed file information in two columns.
    st.subheader("πŸ“‹ Document Details")
    col1, col2 = st.columns(2)
    with col1:
        st.write(f"**File Name:** {result['file_name']}")
        st.write(f"**File Extension:** {result['file_extension']}")
        st.write(f"**Content Length:** {result['content_length']} characters")
    with col2:
        st.write(f"**File Path:** {result['file_path']}")
        st.write(f"**Classification Confidence:** {result['confidence']:.2%}")
    # Display text preview (skipped when the preview string is empty).
    if result['text_preview']:
        st.subheader("πŸ“– Text Preview")
        st.text_area("Content Preview", result['text_preview'], height=150, disabled=True)
    # Display all per-label classification scores, highest first.
    st.subheader("πŸ“Š Classification Scores")
    scores_df = pd.DataFrame(list(result['all_scores'].items()), columns=['Document Type', 'Score'])
    scores_df['Score'] = scores_df['Score'].round(4)
    scores_df = scores_df.sort_values('Score', ascending=False)
    # Create a bar chart of the scores.
    fig = px.bar(scores_df, x='Document Type', y='Score',
                 title="Classification Confidence Scores",
                 color='Score',
                 color_continuous_scale='Blues')
    fig.update_layout(xaxis_tickangle=-45)  # slanted labels to avoid overlap
    st.plotly_chart(fig, use_container_width=True)
    # Display the same scores as a table.
    st.dataframe(scores_df, use_container_width=True)
def display_batch_results(results: List[Dict]):
    """Render a batch of classification results: summary, charts, tables.

    Splits `results` into successes and failures by their 'success' flag,
    then shows summary metrics, a type-distribution pie, a detail table for
    successes, and an error line per failure.
    """
    if not results:
        st.warning("No results to display.")
        return
    # Summary statistics: partition once, reuse below.
    successful_results = [r for r in results if r.get('success', False)]
    failed_results = [r for r in results if not r.get('success', False)]
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Total Files", len(results))
    with col2:
        st.metric("Successful", len(successful_results))
    with col3:
        st.metric("Failed", len(failed_results))
    with col4:
        # Fourth column stays empty when there were no successes.
        if successful_results:
            avg_confidence = sum(r['confidence'] for r in successful_results) / len(successful_results)
            st.metric("Avg Confidence", f"{avg_confidence:.2%}")
    # Classification distribution pie chart (successes only).
    if successful_results:
        st.subheader("πŸ“Š Classification Distribution")
        classifications = [r['classification'] for r in successful_results]
        classification_counts = pd.Series(classifications).value_counts()
        fig = px.pie(values=classification_counts.values,
                     names=classification_counts.index,
                     title="Document Type Distribution")
        st.plotly_chart(fig, use_container_width=True)
    # Detailed results table for successful classifications.
    st.subheader("πŸ“‹ Detailed Results")
    if successful_results:
        results_data = []
        for result in successful_results:
            results_data.append({
                'File Name': result['file_name'],
                'File Type': result['file_type'],
                'Classification': result['classification'].title(),
                'Confidence': f"{result['confidence']:.2%}",
                'Content Length': result['content_length']
            })
        results_df = pd.DataFrame(results_data)
        st.dataframe(results_df, use_container_width=True)
    # Show failed results, one error banner per file.
    if failed_results:
        st.subheader("❌ Failed Classifications")
        for result in failed_results:
            st.error(f"**{result.get('file_name', 'Unknown')}**: {result.get('error', 'Unknown error')}")
def main():
    """Main Streamlit application: header, sidebar, and the three tabs.

    Tab 1 classifies a single document (sample scan via the API, or an
    upload via the local classifier); tab 2 classifies a batch of uploads;
    tab 3 shows the latest results and offers CSV/JSON export.
    """
    # Header
    st.markdown('<h1 class="main-header">πŸ“„ Document Classifier</h1>', unsafe_allow_html=True)
    st.markdown("""
    <div style="text-align: center; margin-bottom: 2rem;">
    <p style="font-size: 1.2rem; color: #666;">
    Classify documents using Hugging Face models and content analysis
    </p>
    </div>
    """, unsafe_allow_html=True)
    # Sidebar
    st.sidebar.title("βš™οΈ Settings")
    # Manual (re-)initialization of the classifier.
    if st.sidebar.button("πŸ”„ Initialize Classifier", type="primary"):
        initialize_classifier()
    # Model information
    st.sidebar.subheader("πŸ€– Model Information")
    st.sidebar.info("""
    **Models Used:**
    - Cardiff NLP Twitter RoBERTa Base Emotion
    - DistilBERT Base Uncased (fallback)
    **Supported Formats:**
    - PDF, DOCX, DOC
    - TXT, CSV
    - XLSX, XLS
    - Images (JPG, PNG, etc.)
    """)
    # Main content
    tab1, tab2, tab3 = st.tabs(["πŸ“ Single File", "πŸ“‚ Batch Upload", "πŸ“Š Results"])
    with tab1:
        st.subheader("\U0001F4C1 Classify Single Document")
        # Fix: guard against a missing sample_scans folder — os.listdir on a
        # non-existent path raised FileNotFoundError and crashed this tab.
        sample_scan_folder = os.path.join(os.path.dirname(__file__), "sample_scans")
        if os.path.isdir(sample_scan_folder):
            sample_files = sorted(
                f for f in os.listdir(sample_scan_folder)
                if f.lower().endswith((".png", ".jpg", ".jpeg"))
            )
        else:
            sample_files = []
        selected_sample = st.selectbox("Or select from example scans:", ["--- Select a sample ---"] + sample_files)
        if selected_sample != "--- Select a sample ---":
            sample_file_path = os.path.join(sample_scan_folder, selected_sample)
            st.image(sample_file_path, caption=f"Sample: {selected_sample}", use_column_width=True)
            if st.button("πŸ” Classify Sample Scan", key="classify_sample_scan"):
                with st.spinner("Calling API to classify sample scan..."):
                    result = classify_with_api(sample_file_path)
                display_classification_result(result)
        uploaded_file = st.file_uploader(
            "Choose a document file",
            type=['pdf', 'docx', 'doc', 'txt', 'csv', 'xlsx', 'xls', 'jpg', 'jpeg', 'png'],
            help="Upload a document to classify its type and content"
        )
        if uploaded_file is not None:
            if st.button("πŸ” Classify Document", type="primary"):
                if not initialize_classifier():
                    st.stop()
                # Save uploaded file to a temp path the classifier can read.
                file_path = save_uploaded_file(uploaded_file)
                if file_path:
                    with st.spinner("Classifying document..."):
                        result = classify_single_file(file_path)
                        st.session_state.classification_results = [result]
                    # Clean up temporary file; best-effort, ignore FS errors.
                    try:
                        os.unlink(file_path)
                    except OSError:
                        pass
                    # Display result
                    display_classification_result(result)
    with tab2:
        st.subheader("πŸ“‚ Batch Document Classification")
        uploaded_files = st.file_uploader(
            "Choose multiple document files",
            type=['pdf', 'docx', 'doc', 'txt', 'csv', 'xlsx', 'xls', 'jpg', 'jpeg', 'png'],
            accept_multiple_files=True,
            help="Upload multiple documents to classify them in batch"
        )
        if uploaded_files:
            st.write(f"πŸ“ {len(uploaded_files)} files selected")
            if st.button("πŸ” Classify All Documents", type="primary"):
                if not initialize_classifier():
                    st.stop()
                # Save all uploads first so progress reflects classification only.
                file_paths = []
                for uploaded_file in uploaded_files:
                    file_path = save_uploaded_file(uploaded_file)
                    if file_path:
                        file_paths.append(file_path)
                if file_paths:
                    progress_bar = st.progress(0)
                    status_text = st.empty()
                    results = []
                    for i, file_path in enumerate(file_paths):
                        status_text.text(f"Processing file {i+1}/{len(file_paths)}: {os.path.basename(file_path)}")
                        result = classify_single_file(file_path)
                        results.append(result)
                        progress_bar.progress((i + 1) / len(file_paths))
                        # Clean up temporary file; best-effort, ignore FS errors.
                        try:
                            os.unlink(file_path)
                        except OSError:
                            pass
                    st.session_state.classification_results = results
                    status_text.text("βœ… Classification complete!")
                    # Display batch results
                    display_batch_results(results)
    with tab3:
        st.subheader("πŸ“Š Classification Results")
        if st.session_state.classification_results:
            # One result gets the detailed single view, several get the batch view.
            if len(st.session_state.classification_results) == 1:
                display_classification_result(st.session_state.classification_results[0])
            else:
                display_batch_results(st.session_state.classification_results)
        else:
            st.info("πŸ‘† Upload and classify documents to see results here.")
        # Export results
        if st.session_state.classification_results:
            st.subheader("πŸ’Ύ Export Results")
            col1, col2 = st.columns(2)
            with col1:
                if st.button("πŸ“„ Export as CSV"):
                    # CSV export only includes successful classifications.
                    successful_results = [r for r in st.session_state.classification_results if r.get('success', False)]
                    if successful_results:
                        export_data = []
                        for result in successful_results:
                            export_data.append({
                                'File Name': result['file_name'],
                                'File Type': result['file_type'],
                                'Classification': result['classification'],
                                'Confidence': result['confidence'],
                                'Content Length': result['content_length']
                            })
                        df = pd.DataFrame(export_data)
                        csv = df.to_csv(index=False)
                        st.download_button(
                            label="Download CSV",
                            data=csv,
                            file_name="classification_results.csv",
                            mime="text/csv"
                        )
            with col2:
                if st.button("πŸ“‹ Export as JSON"):
                    # JSON export dumps the raw result dicts, failures included.
                    json_data = json.dumps(st.session_state.classification_results, indent=2)
                    st.download_button(
                        label="Download JSON",
                        data=json_data,
                        file_name="classification_results.json",
                        mime="application/json"
                    )
if __name__ == "__main__":
main()