"""Streamlit app: classify an uploaded text-based PDF as Invoice / Contract / Other.

Flow: pick a Hugging Face checkpoint -> upload a PDF -> extract its text with
PyMuPDF -> run a sequence-classification forward pass -> show the predicted
label, confidence, timing, and a per-class probability breakdown.

NOTE(review): this copy of the file has had the HTML/CSS payloads of the
`st.markdown(..., unsafe_allow_html=True)` calls stripped — only the visible
text survives.  The styled markup (cards, gradient bars, theme CSS) needs to
be restored from version control; placeholders are marked below.
"""

import time

import fitz  # PyMuPDF
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
)

# -------------------------------
# Page Configuration
# -------------------------------
st.set_page_config(
    page_title="PDF Document Classification",
    page_icon="📄",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# -------------------------------
# Beautiful Blue Theme CSS
# -------------------------------
# NOTE(review): the CSS body is empty in this copy — the <style> block was
# stripped.  Restore it or delete this call.
st.markdown("""
""", unsafe_allow_html=True)

# -------------------------------
# Header
# -------------------------------
# NOTE(review): originally wrapped in styled <div> markup (stripped).
st.markdown("""
📄 PDF Document Classification

AI-Powered Document Type Detection • Upload a text-based PDF to classify it as Invoice, Contract, or Other
""", unsafe_allow_html=True)

# -------------------------------
# Model Configuration
# -------------------------------
# Display label -> Hugging Face checkpoint id.
MODEL_OPTIONS = {
    "DistilBERT (distilbert-base-uncased)": "distilbert-base-uncased",
    "TinyBERT (huawei-noah/TinyBERT_General_6L_768D)": "huawei-noah/TinyBERT_General_6L_768D",
}

# Class index -> human-readable document type.
LABELS = {
    0: "Invoice",
    1: "Contract",
    2: "Other",
}
NUM_LABELS = len(LABELS)


# -------------------------------
# Load Model & Tokenizer
# -------------------------------
@st.cache_resource
def load_model(model_name):
    """Load and cache a (tokenizer, model) pair for 3-way classification.

    `st.cache_resource` keeps one instance per checkpoint across reruns.

    NOTE(review): both checkpoints are *base* models, so the classification
    head created by `num_labels=NUM_LABELS` is randomly initialized —
    predictions are not meaningful until the model is fine-tuned on labeled
    Invoice/Contract/Other documents.  Confirm a fine-tuned checkpoint is
    intended here.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=NUM_LABELS,
    )
    model.eval()  # inference only: disable dropout
    return tokenizer, model


# -------------------------------
# PDF Text Extraction
# -------------------------------
def extract_text_from_pdf(pdf_file):
    """Return all extracted text of an uploaded PDF, whitespace-stripped.

    Returns "" for image-only (scanned) PDFs, which the caller treats as an
    error.  The context manager closes the document handle — the previous
    version leaked it.
    """
    with fitz.open(stream=pdf_file.read(), filetype="pdf") as doc:
        return "".join(page.get_text("text") for page in doc).strip()


# -------------------------------
# Main UI
# -------------------------------
col1, col2 = st.columns(2, gap="large")

with col1:
    # NOTE(review): section header markup stripped; text preserved.
    st.markdown('🤖 Model Selection', unsafe_allow_html=True)
    selected_model_label = st.selectbox(
        "Choose your AI model",
        list(MODEL_OPTIONS.keys()),
        label_visibility="visible",
    )
    selected_model_name = MODEL_OPTIONS[selected_model_label]
    tokenizer, model = load_model(selected_model_name)

with col2:
    st.markdown('📎 Document Upload', unsafe_allow_html=True)
    uploaded_file = st.file_uploader(
        "Upload your PDF file",
        type=["pdf"],
        label_visibility="visible",
        help="Select a text-based PDF file to classify",
    )

# -------------------------------
# Processing
# -------------------------------
if uploaded_file:
    st.markdown("---")
    with st.spinner("🔍 Extracting text from PDF..."):
        pdf_text = extract_text_from_pdf(uploaded_file)

    if not pdf_text:
        # Empty extraction => scanned/image-based PDF; nothing to classify.
        st.error("**❌ No Text Found**\n\nThis PDF appears to be image-based. Please upload a text-based PDF file.")
    else:
        st.success("**✅ Text Extracted Successfully!**\n\nReady for classification")

        st.markdown('📝 Text Preview', unsafe_allow_html=True)
        # Slicing already handles short documents — no conditional needed.
        preview_text = pdf_text[:2000]
        # NOTE(review): the preview was rendered inside a styled <div> with
        # newlines converted to HTML line breaks; that markup was stripped.
        st.markdown(
            f"""
{preview_text.replace(chr(13), '')}
""",
            unsafe_allow_html=True,
        )
        # Report the *actual* preview length (old text always said "2000"
        # even for shorter documents).
        st.caption(f"Showing first {len(preview_text)} characters of {len(pdf_text)} total characters")

        st.markdown("---")
        col_btn1, col_btn2, col_btn3 = st.columns([1, 2, 1])
        with col_btn2:
            classify_btn = st.button("🚀 Classify Document", use_container_width=True)

        if classify_btn:
            with st.spinner("🤖 Running AI inference..."):
                start_time = time.time()
                # NOTE(review): pdf_text[:512] truncates to 512 *characters*,
                # not tokens; the tokenizer's truncation already caps the
                # sequence at the model maximum, so this pre-slice discards
                # usable context — confirm it is intentional.
                inputs = tokenizer(
                    pdf_text[:512],
                    return_tensors="pt",
                    truncation=True,
                    padding=True,
                )
                with torch.no_grad():
                    outputs = model(**inputs)
                inference_time = time.time() - start_time

            logits = outputs.logits
            probs = F.softmax(logits, dim=1)
            confidence, predicted_class = torch.max(probs, dim=1)

            # -------------------------------
            # Results
            # -------------------------------
            st.markdown("---")
            st.markdown('🎯 Classification Results', unsafe_allow_html=True)

            doc_type = LABELS[predicted_class.item()]
            conf_percent = confidence.item() * 100

            # Result cards (card markup stripped; visible text preserved).
            col1, col2, col3 = st.columns(3, gap="medium")
            with col1:
                st.markdown(f"""
Model

{selected_model_label.split('(')[0].strip()}
""", unsafe_allow_html=True)
            with col2:
                st.markdown(f"""
Document Type

{doc_type}
""", unsafe_allow_html=True)
            with col3:
                # Traffic-light gradient by confidence tier: green >= 80%,
                # amber >= 60%, red otherwise.
                if conf_percent >= 80:
                    conf_gradient = "linear-gradient(135deg, #22c55e 0%, #16a34a 100%)"
                elif conf_percent >= 60:
                    conf_gradient = "linear-gradient(135deg, #f59e0b 0%, #d97706 100%)"
                else:
                    conf_gradient = "linear-gradient(135deg, #ef4444 0%, #dc2626 100%)"
                # NOTE(review): conf_gradient was interpolated into the
                # (now-stripped) card style attribute; restore the markup.
                st.markdown(f"""
Confidence

{conf_percent:.1f}%
""", unsafe_allow_html=True)

            # Vertical spacer (original markup stripped).
            st.markdown("""
""", unsafe_allow_html=True)

            # Inference time
            col_time1, col_time2, col_time3 = st.columns([1, 2, 1])
            with col_time2:
                st.metric("⚡ Inference Time", f"{inference_time:.3f} seconds")

            # Confidence breakdown
            st.markdown('📈 Confidence Breakdown', unsafe_allow_html=True)
            prob_dict = {LABELS[i]: probs[0][i].item() * 100 for i in range(len(LABELS))}
            # Highest-probability class first.
            sorted_probs = sorted(prob_dict.items(), key=lambda x: x[1], reverse=True)

            for label, prob in sorted_probs:
                # Predicted class gets the saturated gradient, others a
                # lighter one.
                if label == doc_type:
                    bar_gradient = "linear-gradient(90deg, #3b82f6 0%, #2563eb 100%)"
                else:
                    bar_gradient = "linear-gradient(90deg, #93c5fd 0%, #60a5fa 100%)"
                # NOTE(review): bar_gradient styled the (stripped) progress
                # bar markup; restore it for the colors to take effect.
                st.markdown(f"""
{label} {prob:.1f}%
""", unsafe_allow_html=True)