# doc_classify / streamlit_app.py
# raahinaez's picture
# Update streamlit_app.py
# 1338d4e verified
import html
import time

import fitz  # PyMuPDF
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification
)
# -------------------------------
# Page Configuration
# -------------------------------
# Browser-tab title/icon and overall layout for the app. This is the first
# Streamlit call in the script; older Streamlit releases require
# set_page_config to precede every other st.* command.
st.set_page_config(
    page_title="PDF Document Classification",
    page_icon="📄",  # a single emoji character (U+1F4C4), not a multi-character string
    layout="wide",
    initial_sidebar_state="collapsed",
)
# -------------------------------
# Beautiful Blue Theme CSS
# -------------------------------
# Inject the app-wide stylesheet as a raw <style> element. unsafe_allow_html=True
# is needed so the tag is passed through to the page rather than escaped.
# The rules below cover, in order: page background and layout, headings,
# cards, selectbox, file uploader, buttons, text areas, metrics,
# success/error boxes, result cards, progress bars, divider, spinner,
# text colours, and section headers.
# NOTE(review): many selectors (.stApp, .stButton, [data-testid="..."]) target
# markup generated by Streamlit itself -- presumably version-dependent;
# re-check them after a Streamlit upgrade.
# NOTE(review): the `textarea::value` / `textarea::content` rule does not look
# like standard CSS pseudo-elements -- confirm whether it has any effect.
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;500;600;700;800&display=swap');
* {
font-family: 'Poppins', sans-serif;
}
.stApp {
background: linear-gradient(135deg, #667eea 0%, #764ba2 25%, #667eea 50%, #4facfe 75%, #00f2fe 100%);
background-size: 400% 400%;
animation: gradientShift 15s ease infinite;
}
@keyframes gradientShift {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
.main .block-container {
max-width: 1100px;
padding-top: 2rem;
padding-bottom: 3rem;
}
/* Header Styling */
h1 {
font-size: 3rem !important;
font-weight: 800 !important;
color: #ffffff !important;
text-align: center;
margin-bottom: 0.5rem !important;
text-shadow: 2px 2px 8px rgba(0,0,0,0.2);
letter-spacing: -0.02em;
}
h2 {
font-size: 1.75rem !important;
font-weight: 700 !important;
color: #1e3a8a !important;
margin-top: 2rem !important;
margin-bottom: 1rem !important;
}
h3 {
font-size: 1.25rem !important;
font-weight: 600 !important;
color: #1e40af !important;
margin-bottom: 0.75rem !important;
}
/* Card Styling */
.blue-card {
background: linear-gradient(135deg, #ffffff 0%, #f0f9ff 100%);
border-radius: 16px;
padding: 2rem;
box-shadow: 0 10px 30px rgba(30, 58, 138, 0.2);
border: 2px solid rgba(59, 130, 246, 0.3);
margin-bottom: 1.5rem;
backdrop-filter: blur(10px);
transition: all 0.3s ease;
}
.blue-card:hover {
transform: translateY(-5px);
box-shadow: 0 15px 40px rgba(30, 58, 138, 0.3);
}
/* Selectbox Styling */
.stSelectbox > div > div {
background: linear-gradient(135deg, #ffffff 0%, #f8fafc 100%) !important;
border: 2px solid #3b82f6 !important;
border-radius: 12px !important;
color: #1e3a8a !important;
}
.stSelectbox label {
font-weight: 600 !important;
color: #1e40af !important;
font-size: 0.9375rem !important;
margin-bottom: 0.75rem !important;
}
.stSelectbox [data-baseweb="select"] {
color: #1e3a8a !important;
font-weight: 500 !important;
}
/* File Uploader Styling */
.stFileUploader {
border: 3px dashed #3b82f6 !important;
border-radius: 16px !important;
padding: 2rem !important;
background: linear-gradient(135deg, rgba(59, 130, 246, 0.1) 0%, rgba(147, 197, 253, 0.1) 100%) !important;
transition: all 0.3s ease !important;
}
.stFileUploader:hover {
border-color: #2563eb !important;
background: linear-gradient(135deg, rgba(59, 130, 246, 0.2) 0%, rgba(147, 197, 253, 0.2) 100%) !important;
transform: scale(1.02);
}
.stFileUploader label {
font-weight: 600 !important;
color: #1e40af !important;
font-size: 0.9375rem !important;
margin-bottom: 0.75rem !important;
}
/* Button Styling */
.stButton > button {
background: linear-gradient(135deg, #3b82f6 0%, #2563eb 50%, #1d4ed8 100%) !important;
color: white !important;
font-weight: 700 !important;
font-size: 1rem !important;
padding: 1rem 2rem !important;
border-radius: 12px !important;
border: none !important;
width: 100% !important;
transition: all 0.3s ease !important;
box-shadow: 0 8px 20px rgba(59, 130, 246, 0.4) !important;
text-transform: uppercase;
letter-spacing: 0.5px;
}
.stButton > button:hover {
background: linear-gradient(135deg, #2563eb 0%, #1d4ed8 50%, #1e40af 100%) !important;
box-shadow: 0 12px 30px rgba(59, 130, 246, 0.6) !important;
transform: translateY(-3px) scale(1.02) !important;
}
/* Text Area Styling */
.stTextArea textarea {
font-family: 'Monaco', 'Menlo', monospace !important;
font-size: 0.875rem !important;
background: #ffffff !important;
border: 2px solid #3b82f6 !important;
border-radius: 12px !important;
color: #1e3a8a !important;
padding: 1rem !important;
}
.stTextArea textarea:disabled {
background: #f8fafc !important;
color: #1e3a8a !important;
-webkit-text-fill-color: #1e3a8a !important;
opacity: 1 !important;
}
.stTextArea textarea::placeholder {
color: #94a3b8 !important;
}
.stTextArea label {
font-weight: 600 !important;
color: #1e40af !important;
font-size: 0.9375rem !important;
}
/* Ensure textarea content is visible */
textarea[disabled] {
color: #1e3a8a !important;
-webkit-text-fill-color: #1e3a8a !important;
}
/* Force text color in textarea */
.stTextArea textarea,
.stTextArea textarea[disabled],
.stTextArea textarea:disabled {
color: #1e3a8a !important;
-webkit-text-fill-color: #1e3a8a !important;
}
/* Additional textarea visibility fixes */
textarea {
color: #1e3a8a !important;
}
textarea[disabled] {
color: #1e3a8a !important;
-webkit-text-fill-color: #1e3a8a !important;
opacity: 1 !important;
}
/* Target Streamlit's textarea wrapper */
[data-testid="stTextArea"] textarea,
[data-testid="stTextArea"] textarea[disabled] {
color: #1e3a8a !important;
-webkit-text-fill-color: #1e3a8a !important;
}
/* Ensure text content is visible */
.stTextArea textarea::value,
.stTextArea textarea::content {
color: #1e3a8a !important;
}
/* Metric Styling */
.stMetric {
background: linear-gradient(135deg, #ffffff 0%, #eff6ff 100%) !important;
padding: 1.5rem !important;
border-radius: 16px !important;
border: 2px solid rgba(59, 130, 246, 0.3) !important;
box-shadow: 0 8px 20px rgba(30, 58, 138, 0.15) !important;
}
.stMetric label {
font-weight: 600 !important;
color: #3b82f6 !important;
font-size: 0.8125rem !important;
text-transform: uppercase;
letter-spacing: 0.1em;
}
.stMetric [data-testid="stMetricValue"] {
font-weight: 800 !important;
color: #1e3a8a !important;
font-size: 2rem !important;
margin-top: 0.5rem !important;
}
/* Success/Error Messages */
.stSuccess {
background: linear-gradient(135deg, rgba(34, 197, 94, 0.15) 0%, rgba(16, 185, 129, 0.15) 100%) !important;
border-left: 5px solid #22c55e !important;
border-radius: 12px !important;
padding: 1.25rem !important;
color: #065f46 !important;
font-weight: 500 !important;
}
.stError {
background: linear-gradient(135deg, rgba(239, 68, 68, 0.15) 0%, rgba(220, 38, 38, 0.15) 100%) !important;
border-left: 5px solid #ef4444 !important;
border-radius: 12px !important;
padding: 1.25rem !important;
color: #991b1b !important;
font-weight: 500 !important;
}
/* Result Cards */
.result-card {
background: linear-gradient(135deg, #3b82f6 0%, #2563eb 50%, #1d4ed8 100%);
color: white;
padding: 2rem;
border-radius: 16px;
text-align: center;
box-shadow: 0 10px 30px rgba(30, 58, 138, 0.4);
transition: all 0.3s ease;
border: 2px solid rgba(255, 255, 255, 0.2);
}
.result-card:hover {
transform: translateY(-5px) scale(1.05);
box-shadow: 0 15px 40px rgba(30, 58, 138, 0.5);
}
.result-card h4 {
font-size: 0.75rem;
font-weight: 700;
text-transform: uppercase;
letter-spacing: 0.15em;
opacity: 0.95;
margin-bottom: 0.75rem;
}
.result-card p {
font-size: 1.75rem;
font-weight: 800;
margin: 0;
text-shadow: 1px 1px 3px rgba(0,0,0,0.2);
}
/* Progress Bars */
.progress-wrapper {
margin-bottom: 1.5rem;
background: linear-gradient(135deg, #ffffff 0%, #f0f9ff 100%);
padding: 1.25rem;
border-radius: 12px;
border: 2px solid rgba(59, 130, 246, 0.2);
box-shadow: 0 4px 10px rgba(30, 58, 138, 0.1);
}
.progress-label {
display: flex;
justify-content: space-between;
margin-bottom: 0.75rem;
font-size: 0.9375rem;
font-weight: 600;
color: #1e40af;
}
.progress-bar {
height: 12px;
background: #e0e7ff;
border-radius: 8px;
overflow: hidden;
box-shadow: inset 0 2px 4px rgba(0,0,0,0.1);
}
.progress-fill {
height: 100%;
border-radius: 8px;
transition: width 0.5s ease;
box-shadow: 0 2px 8px rgba(59, 130, 246, 0.4);
}
/* Divider */
hr {
border: none;
border-top: 3px solid rgba(255, 255, 255, 0.3);
margin: 2.5rem 0;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
/* Spinner */
.stSpinner > div {
border-color: #3b82f6;
}
/* Text Colors - Ensure text in white containers is dark, not white */
.blue-card, .section-header, .progress-wrapper, .stMetric,
.blue-card p, .section-header p, .progress-wrapper p, .stMetric p,
.blue-card span, .section-header span, .progress-wrapper span, .stMetric span,
.blue-card div, .section-header div, .progress-wrapper div, .stMetric div,
.blue-card *, .section-header *, .progress-wrapper *, .stMetric * {
color: #1e3a8a !important;
}
/* Text on gradient background (not in white containers) should be white */
.main p:not(.blue-card p):not(.section-header p):not(.progress-wrapper p):not(.stMetric p),
.main span:not(.blue-card span):not(.section-header span):not(.progress-wrapper span):not(.stMetric span) {
color: #ffffff !important;
}
/* Caption text should be visible */
.main [data-testid="stCaption"] {
color: rgba(255, 255, 255, 0.8) !important;
}
/* Strong text on gradient background */
.main strong:not(.blue-card strong):not(.section-header strong):not(.progress-wrapper strong):not(.stMetric strong),
.main b:not(.blue-card b):not(.section-header b):not(.progress-wrapper b):not(.stMetric b) {
color: #ffffff !important;
font-weight: 700 !important;
text-shadow: 1px 1px 3px rgba(0,0,0,0.2);
}
/* Override for any nested elements in white containers */
.blue-card strong, .section-header strong, .progress-wrapper strong, .stMetric strong,
.blue-card b, .section-header b, .progress-wrapper b, .stMetric b {
color: #1e3a8a !important;
}
/* Section Headers */
.section-header {
background: linear-gradient(135deg, rgba(255, 255, 255, 0.95) 0%, rgba(240, 249, 255, 0.95) 100%);
padding: 1.5rem;
border-radius: 12px;
border: 2px solid rgba(59, 130, 246, 0.3);
margin-bottom: 1rem;
box-shadow: 0 4px 15px rgba(30, 58, 138, 0.15);
}
</style>
""", unsafe_allow_html=True)
# -------------------------------
# Header
# -------------------------------
# Page banner: the title plus a one-line tagline. Raw HTML is used (hence
# unsafe_allow_html=True) so the tagline can carry its own inline styling.
# The string is a fixed literal, so nothing user-supplied reaches the page here.
st.markdown("""
<div style="text-align: center; margin-bottom: 3rem;">
<h1>📄 PDF Document Classification</h1>
<p style="color: rgba(255,255,255,0.95) !important; font-size: 1.125rem !important; font-weight: 400 !important; margin-top: 0.5rem; text-shadow: 1px 1px 4px rgba(0,0,0,0.2);">
<strong style="font-weight: 600;">AI-Powered Document Type Detection</strong> • Upload a text-based PDF to classify it as Invoice, Contract, or Other
</p>
</div>
""", unsafe_allow_html=True)
# -------------------------------
# Model Configuration
# -------------------------------
# (display family, Hugging Face checkpoint id) for every selectable model.
_CHECKPOINTS = (
    ("DistilBERT", "distilbert-base-uncased"),
    ("TinyBERT", "huawei-noah/TinyBERT_General_6L_768D"),
)

# Dropdown label -> checkpoint id. Labels follow the "<family> (<checkpoint>)"
# pattern, in the order the models are listed above.
MODEL_OPTIONS = {f"{family} ({repo})": repo for family, repo in _CHECKPOINTS}

# Class index produced by the classifier -> human-readable document type.
LABELS = dict(enumerate(("Invoice", "Contract", "Other")))

# Size of the classification head requested when a model is loaded.
NUM_LABELS = len(LABELS)
# -------------------------------
# Load Model & Tokenizer
# -------------------------------
@st.cache_resource
def load_model(model_name):
    """Load the tokenizer and sequence classifier for ``model_name``.

    The function is wrapped in ``st.cache_resource``, so repeated calls with
    the same ``model_name`` reuse the already-loaded objects.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is requested with
        ``NUM_LABELS`` output classes and has been switched to eval mode.

    NOTE(review): the checkpoints listed in MODEL_OPTIONS look like base
    (not fine-tuned) models, so the classification head is presumably
    freshly initialised and its predictions untrained -- confirm.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    classifier = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=NUM_LABELS)
    # Inference only: eval mode turns off training-time behaviour such as dropout.
    classifier.eval()
    return tok, classifier
# -------------------------------
# PDF Text Extraction
# -------------------------------
def extract_text_from_pdf(pdf_file):
    """Return the plain text of every page of a PDF.

    Args:
        pdf_file: Binary file-like object holding the PDF; only its
            ``read()`` method is used.

    Returns:
        The text of all pages concatenated in page order, with leading and
        trailing whitespace stripped. An empty string means no text could be
        extracted from any page.
    """
    # Open via a context manager so the document is closed even if extracting
    # a page raises; previously the handle was never released.
    with fitz.open(stream=pdf_file.read(), filetype="pdf") as doc:
        page_texts = [page.get_text("text") for page in doc]
    # Join once rather than growing a string page by page.
    return "".join(page_texts).strip()
# -------------------------------
# Main UI
# -------------------------------
# Two-column input row: model picker on the left, PDF upload on the right.
col1, col2 = st.columns(2, gap="large")

with col1:
    st.markdown('<div class="section-header"><h3 style="margin: 0; color: #1e40af !important;">🤖 Model Selection</h3></div>', unsafe_allow_html=True)
    selected_model_label = st.selectbox(
        "Choose your AI model",
        list(MODEL_OPTIONS.keys()),
        label_visibility="visible"
    )
    # Map the dropdown label back to the checkpoint id and load it; the
    # resulting tokenizer/model are used later when the user classifies a file.
    selected_model_name = MODEL_OPTIONS[selected_model_label]
    tokenizer, model = load_model(selected_model_name)

with col2:
    st.markdown('<div class="section-header"><h3 style="margin: 0; color: #1e40af !important;">📎 Document Upload</h3></div>', unsafe_allow_html=True)
    # Returns None until the user has picked a file.
    uploaded_file = st.file_uploader(
        "Upload your PDF file",
        type=["pdf"],
        label_visibility="visible",
        help="Select a text-based PDF file to classify"
    )
# -------------------------------
# Processing
# -------------------------------
# Number of characters of extracted text shown in the preview box.
PREVIEW_CHAR_LIMIT = 2000
# Maximum input length in *tokens* handed to the classifier. Both selectable
# checkpoints are BERT-style encoders; 512 is their usual position limit.
MAX_SEQUENCE_TOKENS = 512

if uploaded_file:
    st.markdown("---")
    with st.spinner("🔍 Extracting text from PDF..."):
        pdf_text = extract_text_from_pdf(uploaded_file)

    if not pdf_text:
        st.error("**❌ No Text Found**\n\nThis PDF appears to be image-based. Please upload a text-based PDF file.")
    else:
        st.success("**✅ Text Extracted Successfully!**\n\nReady for classification")
        st.markdown('<div class="section-header"><h3 style="margin: 0; color: #1e40af !important;">📝 Text Preview</h3></div>', unsafe_allow_html=True)

        # Display text in a styled container for better visibility.
        preview_text = pdf_text[:PREVIEW_CHAR_LIMIT]
        # The PDF content is untrusted and is rendered with unsafe_allow_html,
        # so escape it first; only the <br> tags added afterwards are markup.
        preview_html = html.escape(preview_text).replace("\n", "<br>").replace("\r", "")
        # Built as a single line so the escaped text can never be re-indented
        # or split into separate Markdown blocks.
        st.markdown(
            '<div style="background: #ffffff; border: 2px solid #3b82f6; border-radius: 12px; padding: 1.5rem; max-height: 300px; overflow-y: auto; font-family: \'Monaco\', \'Menlo\', monospace; font-size: 0.875rem; color: #1e3a8a; line-height: 1.6;">'
            f"{preview_html}"
            "</div>",
            unsafe_allow_html=True,
        )
        # Report the number of characters actually shown (may be < the limit).
        st.caption(f"Showing first {len(preview_text)} characters of {len(pdf_text)} total characters")

        st.markdown("---")
        col_btn1, col_btn2, col_btn3 = st.columns([1, 2, 1])
        with col_btn2:
            classify_btn = st.button("🚀 Classify Document", use_container_width=True)

        if classify_btn:
            with st.spinner("🤖 Running AI inference..."):
                # perf_counter is monotonic, so the measured interval cannot
                # be distorted by wall-clock adjustments.
                start_time = time.perf_counter()
                # Let the tokenizer truncate by token count; slicing the raw
                # string by characters would discard most of the usable input.
                inputs = tokenizer(
                    pdf_text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=MAX_SEQUENCE_TOKENS,
                    padding=True,
                )
                with torch.no_grad():
                    outputs = model(**inputs)
                inference_time = time.perf_counter() - start_time

                logits = outputs.logits
                probs = F.softmax(logits, dim=1)
                confidence, predicted_class = torch.max(probs, dim=1)

            # Results
            st.markdown("---")
            st.markdown('<div class="section-header"><h2 style="margin: 0; text-align: center;">🎯 Classification Results</h2></div>', unsafe_allow_html=True)

            doc_type = LABELS[predicted_class.item()]
            conf_percent = confidence.item() * 100

            # Result cards
            col1, col2, col3 = st.columns(3, gap="medium")
            with col1:
                # The label has the form "<family> (<checkpoint>)"; show the family only.
                st.markdown(f"""
                <div class="result-card">
                <h4>Model</h4>
                <p style="font-size: 1.25rem;">{selected_model_label.split('(')[0].strip()}</p>
                </div>
                """, unsafe_allow_html=True)
            with col2:
                st.markdown(f"""
                <div class="result-card">
                <h4>Document Type</h4>
                <p>{doc_type}</p>
                </div>
                """, unsafe_allow_html=True)
            with col3:
                # Green / amber / red card depending on how confident the top class is.
                if conf_percent >= 80:
                    conf_gradient = "linear-gradient(135deg, #22c55e 0%, #16a34a 100%)"
                elif conf_percent >= 60:
                    conf_gradient = "linear-gradient(135deg, #f59e0b 0%, #d97706 100%)"
                else:
                    conf_gradient = "linear-gradient(135deg, #ef4444 0%, #dc2626 100%)"
                st.markdown(f"""
                <div class="result-card" style="background: {conf_gradient};">
                <h4>Confidence</h4>
                <p>{conf_percent:.1f}%</p>
                </div>
                """, unsafe_allow_html=True)

            st.markdown("<br>", unsafe_allow_html=True)

            # Inference time (covers tokenization plus the forward pass)
            col_time1, col_time2, col_time3 = st.columns([1, 2, 1])
            with col_time2:
                st.metric("⚡ Inference Time", f"{inference_time:.3f} seconds")

            # Confidence breakdown
            st.markdown('<div class="section-header"><h2 style="margin: 0; text-align: center;">📈 Confidence Breakdown</h2></div>', unsafe_allow_html=True)
            prob_dict = {label: probs[0][idx].item() * 100 for idx, label in LABELS.items()}
            sorted_probs = sorted(prob_dict.items(), key=lambda x: x[1], reverse=True)
            for label, prob in sorted_probs:
                # Highlight the predicted class with the darker gradient.
                if label == doc_type:
                    bar_gradient = "linear-gradient(90deg, #3b82f6 0%, #2563eb 100%)"
                else:
                    bar_gradient = "linear-gradient(90deg, #93c5fd 0%, #60a5fa 100%)"
                # Fixed-point width keeps very small probabilities from being
                # written in exponent notation inside the style attribute.
                st.markdown(f"""
                <div class="progress-wrapper">
                <div class="progress-label">
                <span>{label}</span>
                <span style="font-weight: 700; color: #1e3a8a;">{prob:.1f}%</span>
                </div>
                <div class="progress-bar">
                <div class="progress-fill" style="background: {bar_gradient}; width: {prob:.2f}%;"></div>
                </div>
                </div>
                """, unsafe_allow_html=True)