import streamlit as st
import fitz  # PyMuPDF
import torch
import time
import re
import os
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification
)

try:
    import numpy as np

    # NumPy 2.x compatibility workaround for FastText.
    # FastText calls np.array(probs, copy=False); under NumPy 2.x that call
    # raises ValueError whenever a copy would be required. We therefore patch
    # numpy.array BEFORE importing fasttext so that copy=False falls back to
    # np.asarray, which keeps the NumPy 1.x "copy only if needed" semantics.
    if hasattr(np, '__version__'):
        np_version = np.__version__.split('.')
        if len(np_version) > 0 and int(np_version[0]) >= 2:
            _original_np_array = np.array

            def _patched_np_array(obj, dtype=None, copy=True, order='K',
                                  subok=False, ndmin=0, **kwargs):
                """Replacement for np.array that tolerates copy=False on NumPy 2.x.

                Note: a ``copy=`` keyword argument always binds to the explicit
                ``copy`` parameter, so it never appears in **kwargs — no kwargs
                inspection is needed.
                """
                if copy is False:
                    # np.asarray copies only when necessary, matching the old
                    # copy=False behavior. asarray accepts neither subok nor
                    # ndmin, so ndmin is emulated below and subok is dropped.
                    result = np.asarray(obj, dtype=dtype, order=order)
                    while result.ndim < ndmin:
                        result = np.expand_dims(result, 0)
                    return result
                # copy=True (or copy=None) keeps the native behavior.
                return _original_np_array(obj, dtype=dtype, copy=copy,
                                          order=order, subok=subok,
                                          ndmin=ndmin, **kwargs)

            np.array = _patched_np_array

    import fasttext
    FASTTEXT_AVAILABLE = True
except ImportError:
    FASTTEXT_AVAILABLE = False
    np = None

# -------------------------------
# Page Configuration
# -------------------------------
st.set_page_config(
    page_title="PDF Document Classification",
    page_icon="📄",
    layout="wide",
    initial_sidebar_state="collapsed"
)

# -------------------------------
# Beautiful Blue Theme CSS
# -------------------------------
st.markdown(""" """, unsafe_allow_html=True)

# -------------------------------
# Header
# -------------------------------
st.markdown("""
📄 PDF Document Classification

AI-Powered Document Type Detection • Upload a text-based PDF to classify it as Invoice, Contract, or Other
""", unsafe_allow_html=True)

# -------------------------------
# Model Configuration
# -------------------------------
MODEL_OPTIONS = {
    "DistilBERT (distilbert-base-uncased)": "distilbert-base-uncased",
    "TinyBERT (huawei-noah/TinyBERT_General_6L_768D)": "huawei-noah/TinyBERT_General_6L_768D",
}

# FastText configuration: look next to this file first, then fall back to CWD.
FASTTEXT_MODEL_PATH = os.path.join(os.path.dirname(__file__), "doc_classifier.bin")
if not os.path.exists(FASTTEXT_MODEL_PATH):
    FASTTEXT_MODEL_PATH = "doc_classifier.bin"
# Minimum FastText confidence before falling back to keyword rules.
FASTTEXT_THRESHOLD = 0.45

# Add FastText to model options if available
if FASTTEXT_AVAILABLE:
    MODEL_OPTIONS["FastText"] = "fasttext"

LABELS = {
    0: "Invoice",
    1: "Contract",
    2: "Other",
}
NUM_LABELS = len(LABELS)

# -------------------------------
# Keywords for Rule-Based Classification
# -------------------------------
INVOICE_KW = [
    "invoice", "remit", "bill to", "net", "tax", "gst",
    "amount due", "purchase order", "po number",
]
CONTRACT_KW = [
    "agreement", "contract", "terms and conditions", "party",
    "clause", "hereby", "effective date",
]


# -------------------------------
# Load Model & Tokenizer
# -------------------------------
@st.cache_resource
def load_model(model_name):
    """Load and cache the selected classifier.

    Returns (tokenizer, model). The tokenizer is None for FastText, and both
    are None when loading fails (an error is rendered in the UI in that case).
    """
    if model_name == "fasttext":
        if not FASTTEXT_AVAILABLE:
            st.error("FastText is not installed. Please install it using: pip install fasttext")
            return None, None
        try:
            if not os.path.exists(FASTTEXT_MODEL_PATH):
                st.error(f"❌ FastText model file not found: {FASTTEXT_MODEL_PATH}\n\nPlease ensure 'doc_classifier.bin' is in the same directory as streamlit_app.py")
                return None, None
            model = fasttext.load_model(FASTTEXT_MODEL_PATH)
            return None, model  # FastText doesn't need a tokenizer
        except FileNotFoundError:
            st.error(f"❌ FastText model file not found: {FASTTEXT_MODEL_PATH}\n\nPlease ensure 'doc_classifier.bin' is in the repository root.")
            return None, None
        except Exception as e:
            st.error(f"❌ Error loading FastText model: {str(e)}\n\nPlease check that the model file is valid.")
            return None, None

    # Transformer path. NOTE(review): loading a base checkpoint with a fresh
    # num_labels head leaves the classification head randomly initialized
    # unless a fine-tuned checkpoint is supplied — confirm this is intended.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=NUM_LABELS,
    )
    model.eval()
    return tokenizer, model


# -------------------------------
# Text Cleaning (for FastText)
# -------------------------------
def clean_text(text: str) -> str:
    """Lowercase, replace non-alphanumerics with spaces, collapse whitespace."""
    text = text.lower()
    text = re.sub(r"[^a-z0-9\s]", " ", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()


# -------------------------------
# Rule-Based Classification
# -------------------------------
def rule_based_classify(text: str) -> str:
    """Classify by keyword presence; invoice keywords take precedence."""
    text_lower = text.lower()
    if any(k in text_lower for k in INVOICE_KW):
        return "Invoice"
    if any(k in text_lower for k in CONTRACT_KW):
        return "Contract"
    return "Other"


# -------------------------------
# FastText Classification (NumPy 2.x Compatible)
# -------------------------------
def classify_with_fasttext(text: str, model):
    """Classify *text* with FastText, falling back to keyword rules.

    Returns (label, confidence, method) where method is "fasttext" when the
    FastText prediction is confident enough, otherwise "rule-based".
    """
    cleaned_text = clean_text(text)
    try:
        # numpy is already patched at import time if NumPy 2.x is present.
        labels, probs = model.predict(cleaned_text, k=1)
        # Convert to plain Python floats; this works identically whether
        # FastText returned an ndarray or a plain sequence (the original
        # code branched on isinstance but both branches were the same).
        probs = [float(p) for p in probs]
        label = labels[0].replace("__label__", "").capitalize()
        confidence = probs[0]
        if confidence >= FASTTEXT_THRESHOLD and label in ["Invoice", "Contract"]:
            return label, confidence, "fasttext"
        return rule_based_classify(text), confidence, "rule-based"
    except (ValueError, TypeError, AttributeError) as e:
        # Fall back to rules if the prediction itself fails (typically a
        # NumPy 2.x compatibility issue inside fasttext).
        error_msg = str(e)
        if "copy" in error_msg.lower() or "numpy" in error_msg.lower() or "unable to avoid copy" in error_msg.lower():
            st.warning("⚠️ FastText encountered a NumPy 2.x compatibility issue. Using rule-based classification as fallback.")
        else:
            st.warning(f"⚠️ FastText prediction encountered an issue, using rule-based classification: {error_msg}")
        return rule_based_classify(text), 0.5, "rule-based"


# -------------------------------
# PDF Text Extraction
# -------------------------------
def extract_text_from_pdf(pdf_file):
    """Extract plain text from an uploaded PDF (empty string if image-only).

    Uses the Document as a context manager so the file handle is always
    released — the original implementation never closed it.
    """
    with fitz.open(stream=pdf_file.read(), filetype="pdf") as doc:
        text = "".join(page.get_text("text") for page in doc)
    return text.strip()


def _estimate_probability_distribution(doc_type: str, conf_percent: float) -> dict:
    """Build a normalized pseudo-distribution for a single FastText prediction.

    FastText with k=1 only reports the winning label, so the remaining mass
    is split heuristically between the other two classes (same weights as the
    original inline branches).
    """
    residual_split = {
        "Invoice": {"Contract": 0.3, "Other": 0.7},
        "Contract": {"Invoice": 0.3, "Other": 0.7},
        "Other": {"Invoice": 0.4, "Contract": 0.6},
    }
    prob_dict = {doc_type: conf_percent}
    for label, weight in residual_split[doc_type].items():
        prob_dict[label] = max(0, (100 - conf_percent) * weight)
    total = sum(prob_dict.values())
    return {k: (v / total * 100) if total > 0 else 33.33 for k, v in prob_dict.items()}


# -------------------------------
# Main UI
# -------------------------------
col1, col2 = st.columns(2, gap="large")

with col1:
    st.markdown('🤖 Model Selection', unsafe_allow_html=True)
    selected_model_label = st.selectbox(
        "Choose your AI model",
        list(MODEL_OPTIONS.keys()),
        label_visibility="visible",
    )
    selected_model_name = MODEL_OPTIONS[selected_model_label]
    tokenizer, model = load_model(selected_model_name)
    # Stop rendering if model loading failed (error already shown).
    if model is None:
        st.stop()

with col2:
    st.markdown('📎 Document Upload', unsafe_allow_html=True)
    uploaded_file = st.file_uploader(
        "Upload your PDF file",
        type=["pdf"],
        label_visibility="visible",
        help="Select a text-based PDF file to classify",
    )

# -------------------------------
# Processing
# -------------------------------
if uploaded_file:
    st.markdown("---")
    with st.spinner("🔍 Extracting text from PDF..."):
        pdf_text = extract_text_from_pdf(uploaded_file)

    if not pdf_text:
        st.error("**❌ No Text Found**\n\nThis PDF appears to be image-based. Please upload a text-based PDF file.")
    else:
        st.success("**✅ Text Extracted Successfully!**\n\nReady for classification")
        st.markdown('📝 Text Preview', unsafe_allow_html=True)

        # Slicing is safe even when the text is shorter than 2000 chars.
        preview_text = pdf_text[:2000]
        st.markdown(preview_text.replace("\r", ""), unsafe_allow_html=True)
        st.caption(f"Showing first 2000 characters of {len(pdf_text)} total characters")

        st.markdown("---")
        col_btn1, col_btn2, col_btn3 = st.columns([1, 2, 1])
        with col_btn2:
            classify_btn = st.button("🚀 Classify Document", use_container_width=True)

        if classify_btn:
            with st.spinner("🤖 Running AI inference..."):
                start_time = time.time()
                if selected_model_name == "fasttext":
                    doc_type, confidence, method = classify_with_fasttext(pdf_text, model)
                    inference_time = time.time() - start_time
                    conf_percent = confidence * 100
                    prob_dict = _estimate_probability_distribution(doc_type, conf_percent)
                else:
                    # Transformer classification on the first 512 characters
                    # (the tokenizer truncates to model length anyway).
                    inputs = tokenizer(
                        pdf_text[:512],
                        return_tensors="pt",
                        truncation=True,
                        padding=True,
                    )
                    with torch.no_grad():
                        outputs = model(**inputs)
                    inference_time = time.time() - start_time
                    probs = F.softmax(outputs.logits, dim=1)
                    confidence, predicted_class = torch.max(probs, dim=1)
                    doc_type = LABELS[predicted_class.item()]
                    conf_percent = confidence.item() * 100
                    method = "transformer"
                    prob_dict = {LABELS[i]: probs[0][i].item() * 100 for i in range(len(LABELS))}

            # -------------------------------
            # Results
            # -------------------------------
            st.markdown("---")
            st.markdown('🎯 Classification Results', unsafe_allow_html=True)

            col1, col2, col3, col4 = st.columns(4, gap="medium")
            with col1:
                st.markdown(
                    f'<div class="result-card"><div>Model</div>'
                    f'<div>{selected_model_label.split("(")[0].strip()}</div></div>',
                    unsafe_allow_html=True,
                )
            with col2:
                st.markdown(
                    f'<div class="result-card"><div>Document Type</div>'
                    f'<div>{doc_type}</div></div>',
                    unsafe_allow_html=True,
                )
            with col3:
                # Green / amber / red gradient by confidence band.
                if conf_percent >= 80:
                    conf_gradient = "linear-gradient(135deg, #22c55e 0%, #16a34a 100%)"
                elif conf_percent >= 60:
                    conf_gradient = "linear-gradient(135deg, #f59e0b 0%, #d97706 100%)"
                else:
                    conf_gradient = "linear-gradient(135deg, #ef4444 0%, #dc2626 100%)"
                st.markdown(
                    f'<div class="result-card" style="background: {conf_gradient}">'
                    f'<div>Confidence</div><div>{conf_percent:.1f}%</div></div>',
                    unsafe_allow_html=True,
                )
            with col4:
                if method == "transformer":
                    method_gradient = "linear-gradient(135deg, #8b5cf6 0%, #7c3aed 100%)"
                elif method == "fasttext":
                    method_gradient = "linear-gradient(135deg, #06b6d4 0%, #0891b2 100%)"
                else:
                    method_gradient = "linear-gradient(135deg, #f59e0b 0%, #d97706 100%)"
                st.markdown(
                    f'<div class="result-card" style="background: {method_gradient}">'
                    f'<div>Method</div><div>{method}</div></div>',
                    unsafe_allow_html=True,
                )

            st.markdown("", unsafe_allow_html=True)

            # Inference time
            col_time1, col_time2, col_time3 = st.columns([1, 2, 1])
            with col_time2:
                st.metric("⚡ Inference Time", f"{inference_time:.3f} seconds")

            # Confidence breakdown, highest probability first; the predicted
            # class gets the stronger bar gradient.
            st.markdown('📈 Confidence Breakdown', unsafe_allow_html=True)
            sorted_probs = sorted(prob_dict.items(), key=lambda x: x[1], reverse=True)
            for label, prob in sorted_probs:
                if label == doc_type:
                    bar_gradient = "linear-gradient(90deg, #3b82f6 0%, #2563eb 100%)"
                else:
                    bar_gradient = "linear-gradient(90deg, #93c5fd 0%, #60a5fa 100%)"
                st.markdown(
                    f'<div style="background: {bar_gradient}">{label} {prob:.1f}%</div>',
                    unsafe_allow_html=True,
                )