import streamlit as st
import fitz # PyMuPDF
import torch
import time
import torch.nn.functional as F
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification
)
# -------------------------------
# Page Configuration
# -------------------------------
# Must be the first Streamlit command executed in the script.
st.set_page_config(
    page_title="PDF Document Classification",
    page_icon="📄",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# -------------------------------
# Beautiful Blue Theme CSS
# -------------------------------
# NOTE(review): this markdown body is empty — the original <style> block
# appears to have been stripped from the source; confirm and restore.
st.markdown("""
""", unsafe_allow_html=True)

# -------------------------------
# Header
# -------------------------------
st.markdown("""
📄 PDF Document Classification
AI-Powered Document Type Detection • Upload a text-based PDF to classify it as Invoice, Contract, or Other
""", unsafe_allow_html=True)
# -------------------------------
# Model Configuration
# -------------------------------
# UI label shown in the selectbox -> Hugging Face checkpoint id.
MODEL_OPTIONS = {
    "DistilBERT (distilbert-base-uncased)": "distilbert-base-uncased",
    "TinyBERT (huawei-noah/TinyBERT_General_6L_768D)": "huawei-noah/TinyBERT_General_6L_768D",
}

# Class index (model output) -> human-readable document type.
LABELS = {0: "Invoice", 1: "Contract", 2: "Other"}

# Derived so the classifier head size always matches the label map.
NUM_LABELS = len(LABELS)
# -------------------------------
# Load Model & Tokenizer
# -------------------------------
@st.cache_resource
def load_model(model_name):
    """Load and cache the tokenizer/model pair for ``model_name``.

    Cached by Streamlit so each checkpoint is downloaded and built once
    per server process.

    NOTE(review): both selectable checkpoints are *base* models, so the
    3-way classification head created by ``num_labels=NUM_LABELS`` is
    randomly initialized — predictions are not meaningful until the
    model is fine-tuned. Confirm this is intended.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=NUM_LABELS
    )
    model.eval()  # inference mode: disables dropout
    return tokenizer, model
# -------------------------------
# PDF Text Extraction
# -------------------------------
def extract_text_from_pdf(pdf_file):
    """Extract the plain-text layer from an uploaded PDF.

    Parameters
    ----------
    pdf_file : file-like
        Object with a ``read()`` returning the raw PDF bytes
        (e.g. a Streamlit ``UploadedFile``).

    Returns
    -------
    str
        Concatenated text of all pages, stripped. An empty string
        signals an image-only PDF with no extractable text.
    """
    # FIX: the original never closed the fitz document (resource leak)
    # and built the text with quadratic ``+=``; use a context manager
    # and a single join instead.
    with fitz.open(stream=pdf_file.read(), filetype="pdf") as doc:
        return "".join(page.get_text("text") for page in doc).strip()
# -------------------------------
# Main UI
# -------------------------------
col1, col2 = st.columns(2, gap="large")

with col1:
    st.markdown('', unsafe_allow_html=True)
    # The selectbox shows friendly labels; map back to the checkpoint id.
    selected_model_label = st.selectbox(
        "Choose your AI model",
        list(MODEL_OPTIONS.keys()),
        label_visibility="visible",
    )
    selected_model_name = MODEL_OPTIONS[selected_model_label]
    tokenizer, model = load_model(selected_model_name)

with col2:
    st.markdown('', unsafe_allow_html=True)
    uploaded_file = st.file_uploader(
        "Upload your PDF file",
        type=["pdf"],
        label_visibility="visible",
        help="Select a text-based PDF file to classify",
    )
# -------------------------------
# Processing
# -------------------------------
if uploaded_file:
st.markdown("---")
with st.spinner("🔍 Extracting text from PDF..."):
pdf_text = extract_text_from_pdf(uploaded_file)
if not pdf_text:
st.error("**❌ No Text Found**\n\nThis PDF appears to be image-based. Please upload a text-based PDF file.")
else:
st.success("**✅ Text Extracted Successfully!**\n\nReady for classification")
st.markdown('', unsafe_allow_html=True)
# Display text in a styled container for better visibility
preview_text = pdf_text[:2000] if len(pdf_text) > 2000 else pdf_text
st.markdown(f"""
{preview_text.replace(chr(10), '
').replace(chr(13), '')}
""", unsafe_allow_html=True)
st.caption(f"Showing first 2000 characters of {len(pdf_text)} total characters")
st.markdown("---")
col_btn1, col_btn2, col_btn3 = st.columns([1, 2, 1])
with col_btn2:
classify_btn = st.button("🚀 Classify Document", use_container_width=True)
if classify_btn:
with st.spinner("🤖 Running AI inference..."):
start_time = time.time()
inputs = tokenizer(
pdf_text[:512],
return_tensors="pt",
truncation=True,
padding=True
)
with torch.no_grad():
outputs = model(**inputs)
inference_time = time.time() - start_time
logits = outputs.logits
probs = F.softmax(logits, dim=1)
confidence, predicted_class = torch.max(probs, dim=1)
# Results
st.markdown("---")
st.markdown('', unsafe_allow_html=True)
doc_type = LABELS[predicted_class.item()]
conf_percent = confidence.item() * 100
# Result cards
col1, col2, col3 = st.columns(3, gap="medium")
with col1:
st.markdown(f"""
Model
{selected_model_label.split('(')[0].strip()}
""", unsafe_allow_html=True)
with col2:
st.markdown(f"""
""", unsafe_allow_html=True)
with col3:
if conf_percent >= 80:
conf_gradient = "linear-gradient(135deg, #22c55e 0%, #16a34a 100%)"
elif conf_percent >= 60:
conf_gradient = "linear-gradient(135deg, #f59e0b 0%, #d97706 100%)"
else:
conf_gradient = "linear-gradient(135deg, #ef4444 0%, #dc2626 100%)"
st.markdown(f"""
Confidence
{conf_percent:.1f}%
""", unsafe_allow_html=True)
st.markdown("
", unsafe_allow_html=True)
# Inference time
col_time1, col_time2, col_time3 = st.columns([1, 2, 1])
with col_time2:
st.metric("⚡ Inference Time", f"{inference_time:.3f} seconds")
# Confidence breakdown
st.markdown('', unsafe_allow_html=True)
prob_dict = {LABELS[i]: probs[0][i].item() * 100 for i in range(len(LABELS))}
sorted_probs = sorted(prob_dict.items(), key=lambda x: x[1], reverse=True)
for label, prob in sorted_probs:
if label == doc_type:
bar_gradient = "linear-gradient(90deg, #3b82f6 0%, #2563eb 100%)"
else:
bar_gradient = "linear-gradient(90deg, #93c5fd 0%, #60a5fa 100%)"
st.markdown(f"""
""", unsafe_allow_html=True)