import streamlit as st
import fitz # PyMuPDF
import torch
import time
import re
import os
import torch.nn.functional as F
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification
)
# FastText import with a NumPy 2.x compatibility shim.
try:
    import numpy as np
    # NumPy 2.x compatibility workaround for FastText.
    # FastText calls np.array(probs, copy=False); in NumPy 2.x copy=False
    # means "never copy" and raises when a copy is required, so we patch
    # numpy.array BEFORE importing fasttext.
    if hasattr(np, '__version__'):
        np_version = np.__version__.split('.')
        if len(np_version) > 0 and int(np_version[0]) >= 2:
            # Keep a reference to the real np.array so the wrapper can
            # delegate for the ordinary copy=True path.
            _original_np_array = np.array
            def _patched_np_array(obj, dtype=None, copy=True, order='K', subok=False, ndmin=0, **kwargs):
                # 'copy' may arrive via the named parameter or inside **kwargs
                # depending on the caller; check both (kwargs wins if present).
                copy_arg = kwargs.pop('copy', None)
                if copy_arg is None:
                    copy_arg = copy
                if copy_arg is False:
                    # np.asarray copies only when needed, matching the old
                    # NumPy 1.x copy=False semantics FastText relies on.
                    # NOTE: asarray accepts no subok parameter, so subok is
                    # silently ignored on this path.
                    result = np.asarray(obj, dtype=dtype, order=order)
                    if ndmin > 0:
                        # Emulate ndmin by prepending axes until the
                        # requested rank is reached.
                        while result.ndim < ndmin:
                            result = np.expand_dims(result, 0)
                    return result
                # For copy=True (or any non-False value), use the original function.
                return _original_np_array(obj, dtype=dtype, copy=copy_arg, order=order, subok=subok, ndmin=ndmin, **kwargs)
            np.array = _patched_np_array
    import fasttext
    FASTTEXT_AVAILABLE = True
except ImportError:
    # fasttext (or numpy) is not installed; the app falls back to
    # transformer models only (see MODEL_OPTIONS below).
    FASTTEXT_AVAILABLE = False
    np = None
# -------------------------------
# Page Configuration
# -------------------------------
st.set_page_config(
    page_title="PDF Document Classification",
    page_icon="📄",
    layout="wide",
    initial_sidebar_state="collapsed"
)
# -------------------------------
# Beautiful Blue Theme CSS
# -------------------------------
# NOTE(review): this markdown body is empty — the CSS appears to have been
# stripped from the file; confirm against the original stylesheet.
st.markdown("""
""", unsafe_allow_html=True)
# -------------------------------
# Header
# -------------------------------
st.markdown("""
📄 PDF Document Classification
AI-Powered Document Type Detection • Upload a text-based PDF to classify it as Invoice, Contract, or Other
""", unsafe_allow_html=True)
# -------------------------------
# Model Configuration
# -------------------------------
# Selectbox label -> model identifier (a Hugging Face repo id, or the
# sentinel "fasttext" appended below when FastText is available).
MODEL_OPTIONS = {
    "DistilBERT (distilbert-base-uncased)": "distilbert-base-uncased",
    "TinyBERT (huawei-noah/TinyBERT_General_6L_768D)": "huawei-noah/TinyBERT_General_6L_768D"
}
# FastText Configuration
# The pretrained .bin classifier is expected next to this script.
FASTTEXT_MODEL_PATH = os.path.join(os.path.dirname(__file__), "doc_classifier.bin")
# Fallback to relative path if __file__ is not available
if not os.path.exists(FASTTEXT_MODEL_PATH):
    FASTTEXT_MODEL_PATH = "doc_classifier.bin"
# Minimum FastText confidence required to trust its prediction; below this
# the app falls back to the keyword rules (see classify_with_fasttext).
FASTTEXT_THRESHOLD = 0.45
# Add FastText to model options if available
if FASTTEXT_AVAILABLE:
    MODEL_OPTIONS["FastText"] = "fasttext"
# Class index -> human-readable label; the order must match the transformer
# head's logit ordering.
LABELS = {
    0: "Invoice",
    1: "Contract",
    2: "Other"
}
NUM_LABELS = len(LABELS)
# -------------------------------
# Keywords for Rule-Based Classification
# -------------------------------
# Substrings searched case-insensitively by rule_based_classify.
INVOICE_KW = [
    "invoice", "remit", "bill to", "net", "tax", "gst",
    "amount due", "purchase order", "po number"
]
CONTRACT_KW = [
    "agreement", "contract", "terms and conditions",
    "party", "clause", "hereby", "effective date"
]
# -------------------------------
# Load Model & Tokenizer
# -------------------------------
@st.cache_resource
def load_model(model_name):
    """Load and cache the selected classifier.

    Returns a (tokenizer, model) pair. For FastText the tokenizer slot is
    None; on any loading failure both slots are None and an st.error has
    already been rendered.
    """
    if model_name != "fasttext":
        # Hugging Face transformer path: tokenizer + sequence-classification head.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=NUM_LABELS
        )
        model.eval()
        return tokenizer, model
    # FastText path.
    if not FASTTEXT_AVAILABLE:
        st.error("FastText is not installed. Please install it using: pip install fasttext")
        return None, None
    try:
        # Check up front that the model file is present to give a clearer message.
        if not os.path.exists(FASTTEXT_MODEL_PATH):
            st.error(f"❌ FastText model file not found: {FASTTEXT_MODEL_PATH}\n\nPlease ensure 'doc_classifier.bin' is in the same directory as streamlit_app.py")
            return None, None
        # FastText doesn't need a tokenizer.
        return None, fasttext.load_model(FASTTEXT_MODEL_PATH)
    except FileNotFoundError:
        st.error(f"❌ FastText model file not found: {FASTTEXT_MODEL_PATH}\n\nPlease ensure 'doc_classifier.bin' is in the repository root.")
        return None, None
    except Exception as e:
        st.error(f"❌ Error loading FastText model: {str(e)}\n\nPlease check that the model file is valid.")
        return None, None
# -------------------------------
# Text Cleaning (for FastText)
# -------------------------------
def clean_text(text: str) -> str:
    """Normalize text for FastText: lowercase, alphanumerics only, single-spaced.

    Any character outside [a-z0-9] (after lowercasing) becomes a separator;
    runs of whitespace collapse to one space with no leading/trailing space.
    """
    normalized = re.sub(r"[^a-z0-9\s]", " ", text.lower())
    # split()/join collapses all whitespace runs and trims the ends.
    return " ".join(normalized.split())
# -------------------------------
# Rule-Based Classification
# -------------------------------
def rule_based_classify(text: str) -> str:
    """Keyword heuristic over the raw text.

    Invoice keywords take precedence over Contract keywords; anything
    matching neither list is labelled "Other".
    """
    haystack = text.lower()
    for label, keywords in (("Invoice", INVOICE_KW), ("Contract", CONTRACT_KW)):
        if any(keyword in haystack for keyword in keywords):
            return label
    return "Other"
# -------------------------------
# FastText Classification (NumPy 2.x Compatible)
# -------------------------------
def classify_with_fasttext(text: str, model):
    """Classify text with FastText, falling back to keyword rules.

    Returns a (label, confidence, method) tuple where method is "fasttext"
    when the model's top prediction is a confident Invoice/Contract, and
    "rule-based" when confidence is low, the label is unexpected, or the
    prediction raised (e.g. a NumPy 2.x compatibility error).
    """
    cleaned_text = clean_text(text)
    try:
        # k=1 -> (labels, probs) for the single best class.
        labels, probs = model.predict(cleaned_text, k=1)
        # float() works for both plain sequences and numpy arrays, so no
        # ndarray special-casing is needed (the original if/else branches
        # were identical).
        confidence = float(probs[0])
        label = labels[0].replace("__label__", "").capitalize()
        if confidence >= FASTTEXT_THRESHOLD and label in ["Invoice", "Contract"]:
            return label, confidence, "fasttext"
        # Low confidence or an unexpected label: defer to the keyword rules
        # but keep the model's confidence for display.
        return rule_based_classify(text), confidence, "rule-based"
    except (ValueError, TypeError, AttributeError) as e:
        # Fallback to rule-based if FastText prediction fails (NumPy 2.x compatibility issue)
        error_msg = str(e)
        # "copy" also covers NumPy's "unable to avoid copy" message.
        if "copy" in error_msg.lower() or "numpy" in error_msg.lower():
            st.warning("⚠️ FastText encountered a NumPy 2.x compatibility issue. Using rule-based classification as fallback.")
        else:
            st.warning(f"⚠️ FastText prediction encountered an issue, using rule-based classification: {error_msg}")
        # Neutral confidence since the model produced nothing usable.
        return rule_based_classify(text), 0.5, "rule-based"
# -------------------------------
# PDF Text Extraction
# -------------------------------
def extract_text_from_pdf(pdf_file):
    """Extract plain text from an uploaded PDF file-like object.

    Returns the concatenated text of all pages, stripped of surrounding
    whitespace; an empty string indicates an image-only (scanned) PDF.
    """
    # Open from an in-memory stream; the context manager closes the document
    # (the original leaked the fitz.Document handle).
    with fitz.open(stream=pdf_file.read(), filetype="pdf") as doc:
        # join() avoids the quadratic += string build of the original loop.
        return "".join(page.get_text("text") for page in doc).strip()
# -------------------------------
# Main UI
# -------------------------------
# Two-column layout: model picker on the left, file uploader on the right.
col1, col2 = st.columns(2, gap="large")
with col1:
    st.markdown('', unsafe_allow_html=True)
    selected_model_label = st.selectbox(
        "Choose your AI model",
        list(MODEL_OPTIONS.keys()),
        label_visibility="visible"
    )
    # Map the human-readable label to the model identifier
    # (a Hugging Face repo id, or "fasttext").
    selected_model_name = MODEL_OPTIONS[selected_model_label]
    tokenizer, model = load_model(selected_model_name)
    # Check if model loading failed; load_model already rendered an st.error,
    # so just halt the script here.
    if model is None:
        st.stop()
with col2:
    st.markdown('', unsafe_allow_html=True)
    uploaded_file = st.file_uploader(
        "Upload your PDF file",
        type=["pdf"],
        label_visibility="visible",
        help="Select a text-based PDF file to classify"
    )
# -------------------------------
# Processing
# -------------------------------
# Runs only after a PDF is uploaded: extract text, show a preview, then
# classify on button click and render the results dashboard.
if uploaded_file:
    st.markdown("---")
    with st.spinner("🔍 Extracting text from PDF..."):
        pdf_text = extract_text_from_pdf(uploaded_file)
    if not pdf_text:
        # Empty extraction usually means a scanned/image-only PDF.
        st.error("**❌ No Text Found**\n\nThis PDF appears to be image-based. Please upload a text-based PDF file.")
    else:
        st.success("**✅ Text Extracted Successfully!**\n\nReady for classification")
        st.markdown('', unsafe_allow_html=True)
        # Display text in a styled container for better visibility
        preview_text = pdf_text[:2000] if len(pdf_text) > 2000 else pdf_text
        # NOTE(review): the replacement string below spans a line break — the
        # markup it inserted looks stripped; confirm against the original template.
        st.markdown(f"""
{preview_text.replace(chr(10), '
').replace(chr(13), '')}
""", unsafe_allow_html=True)
        st.caption(f"Showing first 2000 characters of {len(pdf_text)} total characters")
        st.markdown("---")
        # Centered classify button.
        col_btn1, col_btn2, col_btn3 = st.columns([1, 2, 1])
        with col_btn2:
            classify_btn = st.button("🚀 Classify Document", use_container_width=True)
        if classify_btn:
            with st.spinner("🤖 Running AI inference..."):
                start_time = time.time()
                # Check if using FastText or Transformer model
                if selected_model_name == "fasttext":
                    doc_type, confidence, method = classify_with_fasttext(pdf_text, model)
                    conf_percent = confidence * 100
                    inference_time = time.time() - start_time
                    # Create probability distribution for FastText.
                    # FastText only returns the top prediction, so the other
                    # two class shares are estimated from the remainder.
                    prob_dict = {}
                    if doc_type == "Invoice":
                        prob_dict["Invoice"] = conf_percent
                        prob_dict["Contract"] = max(0, (100 - conf_percent) * 0.3)
                        prob_dict["Other"] = max(0, (100 - conf_percent) * 0.7)
                    elif doc_type == "Contract":
                        prob_dict["Contract"] = conf_percent
                        prob_dict["Invoice"] = max(0, (100 - conf_percent) * 0.3)
                        prob_dict["Other"] = max(0, (100 - conf_percent) * 0.7)
                    else:
                        prob_dict["Other"] = conf_percent
                        prob_dict["Invoice"] = max(0, (100 - conf_percent) * 0.4)
                        prob_dict["Contract"] = max(0, (100 - conf_percent) * 0.6)
                    # Normalize so the three bars sum to 100%.
                    total = sum(prob_dict.values())
                    prob_dict = {k: (v / total * 100) if total > 0 else 33.33 for k, v in prob_dict.items()}
                else:
                    # Transformer model classification (first 512 characters only).
                    inputs = tokenizer(
                        pdf_text[:512],
                        return_tensors="pt",
                        truncation=True,
                        padding=True
                    )
                    with torch.no_grad():
                        outputs = model(**inputs)
                    inference_time = time.time() - start_time
                    logits = outputs.logits
                    probs = F.softmax(logits, dim=1)
                    confidence, predicted_class = torch.max(probs, dim=1)
                    doc_type = LABELS[predicted_class.item()]
                    conf_percent = confidence.item() * 100
                    method = "transformer"
                    # Create probability distribution for Transformer
                    prob_dict = {LABELS[i]: probs[0][i].item() * 100 for i in range(len(LABELS))}
            # Results
            # NOTE(review): several markdown bodies below look like their HTML
            # was stripped from the file — confirm against the original template.
            st.markdown("---")
            st.markdown('', unsafe_allow_html=True)
            # Result cards (note: col1..col4 shadow the outer layout columns).
            col1, col2, col3, col4 = st.columns(4, gap="medium")
            with col1:
                st.markdown(f"""
Model
{selected_model_label.split('(')[0].strip()}
""", unsafe_allow_html=True)
            with col2:
                st.markdown(f"""
""", unsafe_allow_html=True)
            with col3:
                # Traffic-light gradient keyed on the confidence level.
                if conf_percent >= 80:
                    conf_gradient = "linear-gradient(135deg, #22c55e 0%, #16a34a 100%)"
                elif conf_percent >= 60:
                    conf_gradient = "linear-gradient(135deg, #f59e0b 0%, #d97706 100%)"
                else:
                    conf_gradient = "linear-gradient(135deg, #ef4444 0%, #dc2626 100%)"
                st.markdown(f"""
Confidence
{conf_percent:.1f}%
""", unsafe_allow_html=True)
            with col4:
                # Gradient per classification method (transformer / fasttext / rule-based).
                method_gradient = "linear-gradient(135deg, #8b5cf6 0%, #7c3aed 100%)" if method == "transformer" else \
                    "linear-gradient(135deg, #06b6d4 0%, #0891b2 100%)" if method == "fasttext" else \
                    "linear-gradient(135deg, #f59e0b 0%, #d97706 100%)"
                st.markdown(f"""
""", unsafe_allow_html=True)
            st.markdown("""
""", unsafe_allow_html=True)
            # Inference time
            col_time1, col_time2, col_time3 = st.columns([1, 2, 1])
            with col_time2:
                st.metric("⚡ Inference Time", f"{inference_time:.3f} seconds")
            # Confidence breakdown, highest probability first.
            st.markdown('', unsafe_allow_html=True)
            sorted_probs = sorted(prob_dict.items(), key=lambda x: x[1], reverse=True)
            for label, prob in sorted_probs:
                # The predicted class gets the stronger gradient.
                if label == doc_type:
                    bar_gradient = "linear-gradient(90deg, #3b82f6 0%, #2563eb 100%)"
                else:
                    bar_gradient = "linear-gradient(90deg, #93c5fd 0%, #60a5fa 100%)"
                st.markdown(f"""
""", unsafe_allow_html=True)