# app.py — Hugging Face Space by raahinaez (commit 3d1c323, verified)
import streamlit as st
import fitz # PyMuPDF
import pytesseract
from PIL import Image
import io
import time
import os
from transformers import pipeline, AutoTokenizer
# Optional dependency: fasttext requires a native build and may be absent,
# so import it defensively instead of crashing at startup.
try:
    import fasttext
except ImportError:
    fasttext = None

# Single source of truth: available exactly when the import succeeded.
FASTTEXT_AVAILABLE = fasttext is not None
# Configure the Tesseract binary location.
# Windows: use the default install path if present.
# Linux/Unix (e.g. HF Spaces after packages.txt install): prefer whatever
# `shutil.which` finds on PATH, then common absolute locations, then fall
# back to the bare command name (same final behavior as before, but without
# re-importing shutil and re-running the PATH lookup on every iteration).
if os.name == 'nt':  # Windows
    tesseract_path = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
    if os.path.exists(tesseract_path):
        pytesseract.pytesseract.tesseract_cmd = tesseract_path
else:
    import shutil

    tesseract_cmd = shutil.which('tesseract')
    if tesseract_cmd is None:
        # PATH lookup failed: probe well-known install locations.
        for candidate in ('/usr/bin/tesseract', '/usr/local/bin/tesseract'):
            if os.path.exists(candidate):
                tesseract_cmd = candidate
                break
        else:
            # Last resort: let pytesseract try the plain command name.
            tesseract_cmd = 'tesseract'
    pytesseract.pytesseract.tesseract_cmd = tesseract_cmd
def get_hf_token():
    """Return a Hugging Face token, or None if none is configured.

    Lookup order: Streamlit secrets (flat ``HF_TOKEN`` key, then nested
    ``HF.TOKEN``) for Spaces deployments, then the ``HF_TOKEN`` /
    ``HUGGINGFACEHUB_API_TOKEN`` environment variables for local runs.
    Any error while reading secrets is swallowed so the env fallback runs.
    """
    try:
        secrets = getattr(st, 'secrets', None)
        if secrets is not None:
            if 'HF_TOKEN' in secrets:
                return secrets['HF_TOKEN']
            if 'HF' in secrets and 'TOKEN' in secrets['HF']:
                return secrets['HF']['TOKEN']
    except Exception:
        # st.secrets can raise when no secrets file exists; ignore and
        # fall through to the environment-variable lookup.
        pass
    return os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
# Configure page
# NOTE: st.set_page_config must be the first Streamlit call in the script.
st.set_page_config(
    page_title="Document Classification Performance Testing",
    page_icon="📄",
    layout="centered"
)
# Initialize session state
# models_loaded caches classifiers/tokenizers across Streamlit reruns so
# models are only loaded once per session.
if 'models_loaded' not in st.session_state:
    st.session_state.models_loaded = {}
def extract_text_from_pdf(pdf_file):
    """Extract all page text from an uploaded PDF using PyMuPDF.

    Args:
        pdf_file: File-like object (e.g. a Streamlit UploadedFile).

    Returns:
        The stripped, concatenated page text, or None on failure (an error
        message is shown in the UI in that case).
    """
    try:
        pdf_bytes = pdf_file.read()
        pdf_file.seek(0)  # reset so the upload can be read again later
        # Context manager guarantees the document is closed even when a page
        # raises mid-loop (the original skipped doc.close() on error).
        with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
            # join() avoids quadratic string concatenation over many pages
            text = "".join(page.get_text() for page in doc)
        return text.strip()
    except Exception as e:
        st.error(f"Error extracting text from PDF: {str(e)}")
        return None
def extract_text_from_image(image_file):
    """Run Tesseract OCR over an uploaded image and return the stripped text.

    Returns None (after surfacing a Streamlit error) when decoding or OCR
    fails.
    """
    try:
        raw_bytes = image_file.read()
        image_file.seek(0)  # leave the upload re-readable for callers
        pil_image = Image.open(io.BytesIO(raw_bytes))
        return pytesseract.image_to_string(pil_image).strip()
    except Exception as e:
        st.error(f"Error extracting text from image: {str(e)}")
        return None
def load_distilbert_model():
    """Return a cached DistilBERT zero-shot pipeline, loading it on first use.

    The pipeline and its tokenizer are stored in st.session_state so reruns
    of the Streamlit script do not reload the model. Returns None when
    loading fails (an error is shown in the UI).

    NOTE(review): 'distilbert-base-uncased' is not NLI-fine-tuned, so
    zero-shot scores may be weak — confirm this choice is intentional.
    """
    cache = st.session_state.models_loaded
    if 'distilbert' not in cache:
        with st.spinner("Loading DistilBERT model (first time may take a while)..."):
            try:
                model_name = "distilbert-base-uncased"
                zero_shot = pipeline(
                    "zero-shot-classification",
                    model=model_name,
                    truncation=True,  # clip over-long inputs automatically
                    max_length=512
                )
                # Keep the tokenizer too, for token-accurate truncation later.
                zero_shot_tokenizer = AutoTokenizer.from_pretrained(model_name)
                cache['distilbert'] = zero_shot
                cache['distilbert_tokenizer'] = zero_shot_tokenizer
            except Exception as e:
                st.error(f"Error loading DistilBERT model: {str(e)}")
                return None
    return cache['distilbert']
def load_tinybert_model():
    """Load and cache a TinyBERT zero-shot classification pipeline.

    Tries 'huawei-noah/TinyBERT_General_6L_768D' first (with the Hugging
    Face token when one is configured, then anonymously), and falls back to
    'google/tinybert-6L-384D' when the primary identifier is rejected.
    The pipeline and its tokenizer are cached in st.session_state.

    Returns:
        The zero-shot classification pipeline, or None when every attempt
        failed (an error message is shown in the UI in that case).
    """
    if 'tinybert' not in st.session_state.models_loaded:
        with st.spinner("Loading TinyBERT model (first time may take a while)..."):
            hf_token = get_hf_token()
            try:
                try:
                    # Prefer an authenticated load (needed for private models).
                    classifier, tokenizer = _load_tinybert_pipeline(
                        "huawei-noah/TinyBERT_General_6L_768D", hf_token
                    )
                except Exception:
                    # Public models need no credentials; retry anonymously.
                    # (Unlike the original, we do NOT re-load the tokenizer
                    # with the token afterwards, which could throw away a
                    # successful anonymous load.)
                    classifier, tokenizer = _load_tinybert_pipeline(
                        "huawei-noah/TinyBERT_General_6L_768D", None
                    )
            except Exception as e:
                error_msg = str(e)
                if "huawei-noah" in error_msg.lower() or "not a valid model identifier" in error_msg.lower():
                    # Primary identifier unknown: try the alternative one.
                    try:
                        st.info("Trying alternative TinyBERT model identifier...")
                        classifier, tokenizer = _load_tinybert_pipeline(
                            "google/tinybert-6L-384D", hf_token
                        )
                    except Exception as e2:
                        st.error(
                            f"Error loading TinyBERT model. Tried both 'huawei-noah/TinyBERT_General_6L_768D' "
                            f"and 'google/tinybert-6L-384D'. "
                            f"If these models are private, ensure your Hugging Face token is set in the "
                            f"HF_TOKEN or HUGGINGFACEHUB_API_TOKEN environment variable. "
                            f"Details: {str(e2)}"
                        )
                        return None
                else:
                    st.error(
                        f"Error loading TinyBERT model. "
                        f"If this model is private, ensure your Hugging Face token is set in the "
                        f"HF_TOKEN or HUGGINGFACEHUB_API_TOKEN environment variable. "
                        f"Details: {error_msg}"
                    )
                    return None
            st.session_state.models_loaded['tinybert'] = classifier
            st.session_state.models_loaded['tinybert_tokenizer'] = tokenizer
    return st.session_state.models_loaded['tinybert']


def _load_tinybert_pipeline(model_name, hf_token):
    """Build a zero-shot classification pipeline and matching tokenizer.

    Args:
        model_name: Hugging Face model identifier.
        hf_token: Optional access token; normalized to None when falsy so
            anonymous access is used.

    Returns:
        (classifier, tokenizer) tuple; raises on any load failure.
    """
    token = hf_token if hf_token else None
    classifier = pipeline(
        "zero-shot-classification",
        model=model_name,
        token=token,
        truncation=True,  # truncate over-long inputs automatically
        max_length=512
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
    return classifier, tokenizer
def load_fasttext_model():
    """Return the cached FastText model, loading 'lid.176.bin' on first use.

    Uses the pre-trained language-identification model as a stand-in; a
    production document classifier would use a custom-trained FastText
    model. Returns None when fasttext is missing, the model file is absent,
    or loading fails (a message is shown in the UI in each case).
    """
    if not FASTTEXT_AVAILABLE:
        st.error("FastText is not installed. Please install it using: pip install fasttext")
        return None
    cache = st.session_state.models_loaded
    if 'fasttext' not in cache:
        with st.spinner("Loading FastText model (first time may take a while)..."):
            model_path = "lid.176.bin"
            if not os.path.exists(model_path):
                st.warning("FastText model file not found. Please download 'lid.176.bin' from "
                           "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin "
                           "and place it in the current directory.")
                return None
            try:
                cache['fasttext'] = fasttext.load_model(model_path)
            except Exception as e:
                st.error(f"Error loading FastText model: {str(e)}")
                return None
    return cache['fasttext']
def truncate_text_for_model(text, tokenizer=None, max_length=500):
    """
    Truncate text to fit within a model's maximum sequence length.

    Args:
        text: The input string.
        tokenizer: Optional HF tokenizer for accurate token counting; when
            absent, a ~4-chars-per-token approximation is used.
        max_length: Maximum number of (non-special) tokens to keep.

    Returns:
        (text, was_truncated): the possibly-shortened text, and a flag that
        is True only when content was actually removed. (The original
        flagged exactly-max_length inputs as truncated and round-tripped
        them through decode, which need not reproduce the input.)
    """
    if tokenizer is not None:
        # Count the full token stream so we only decode when tokens are
        # actually dropped.
        tokens = tokenizer.encode(text, add_special_tokens=False)
        if len(tokens) > max_length:
            return tokenizer.decode(tokens[:max_length], skip_special_tokens=True), True
        return text, False
    # Fallback: character-based approximation (~4 chars per token), with a
    # 12-token buffer for special tokens added by the pipeline.
    max_chars = (max_length - 12) * 4
    if len(text) <= max_chars:
        return text, False
    # Cut at a word boundary so we don't end mid-word.
    truncated = text[:max_chars].rsplit(' ', 1)[0]
    return truncated + "...", True
def classify_with_distilbert(classifier, text, candidate_labels):
    """Zero-shot classify `text` with DistilBERT; return (top_label, top_score).

    Returns (None, None) when no classifier is given or classification
    raises (the error is shown in the UI).
    """
    if classifier is None:
        return None, None
    # Use the cached tokenizer (if any) for token-accurate truncation;
    # DistilBERT's max sequence length is 512, so leave headroom at 500.
    cached_tokenizer = st.session_state.models_loaded.get('distilbert_tokenizer')
    prepared_text, clipped = truncate_text_for_model(
        text, tokenizer=cached_tokenizer, max_length=500
    )
    if clipped:
        st.warning("⚠️ Text was truncated to fit model's maximum input length (512 tokens).")
    try:
        prediction = classifier(prepared_text, candidate_labels, truncation=True)
        return prediction['labels'][0], prediction['scores'][0]
    except Exception as e:
        st.error(f"Classification error: {str(e)}")
        return None, None
def classify_with_tinybert(classifier, text, candidate_labels):
    """Zero-shot classify `text` with TinyBERT; return (top_label, top_score).

    Returns (None, None) when no classifier is given or classification
    raises (the error is shown in the UI).
    """
    if classifier is None:
        return None, None
    # Token-accurate truncation when the tokenizer was cached at load time;
    # TinyBERT's max sequence length is 512, so leave headroom at 500.
    cached_tokenizer = st.session_state.models_loaded.get('tinybert_tokenizer')
    prepared_text, clipped = truncate_text_for_model(
        text, tokenizer=cached_tokenizer, max_length=500
    )
    if clipped:
        st.warning("⚠️ Text was truncated to fit model's maximum input length (512 tokens).")
    try:
        prediction = classifier(prepared_text, candidate_labels, truncation=True)
        return prediction['labels'][0], prediction['scores'][0]
    except Exception as e:
        st.error(f"Classification error: {str(e)}")
        return None, None
def classify_with_fasttext(model, text):
    """Classify text using FastText; return (label, score) or (None, None).

    FastText's predict() raises on input containing newlines, and extracted
    document text almost always contains them, so newlines are flattened to
    spaces first. Error handling mirrors the other classify_* helpers.
    (Language identification is used as a stand-in; a production document
    classifier would use a custom-trained model.)
    """
    if model is None:
        return None, None
    try:
        # predict() rejects multi-line strings — flatten to one line.
        single_line = text.replace('\n', ' ')
        predictions = model.predict(single_line, k=1)
        label = predictions[0][0].replace('__label__', '')
        score = float(predictions[1][0])
        return label, score
    except Exception as e:
        st.error(f"Classification error: {str(e)}")
        return None, None
# Main UI
st.title("📄 Document Classification Performance Testing")
st.markdown("---")
# File upload — the `type` list restricts what the browser will accept.
uploaded_file = st.file_uploader(
    "Upload a document",
    type=['pdf', 'png', 'jpg', 'jpeg'],
    help="Upload a PDF or Image file (PNG/JPG/JPEG)"
)
# Model selection
# NOTE: these option strings are matched exactly by the classify handler
# below, so keep them in sync with that if/elif chain.
model_options = [
    "distilbert-base-uncased",
    "google/tinybert-6L-384D"
]
# FastText option commented out
# if FASTTEXT_AVAILABLE:
#     model_options.append("FastText")
# else:
#     model_options.append("FastText (Not Available)")
model_option = st.selectbox(
    "Select Model",
    options=model_options,
    help="Choose a model for document classification"
)
# FastText warning commented out
# if model_option == "FastText (Not Available)":
#     st.warning("⚠️ FastText is not installed. To use FastText, you can try installing a pre-built wheel: "
#                "`pip install fasttext-wheel` or use conda: `conda install -c conda-forge fasttext`. "
#                "Alternatively, use DistilBERT or TinyBERT models.")
# Classification button
if st.button("Classify Document", type="primary"):
if uploaded_file is None:
st.warning("Please upload a file first.")
else:
# Extract text based on file type
file_extension = uploaded_file.name.split('.')[-1].lower()
if file_extension == 'pdf':
extracted_text = extract_text_from_pdf(uploaded_file)
elif file_extension in ['png', 'jpg', 'jpeg']:
extracted_text = extract_text_from_image(uploaded_file)
else:
st.error("Unsupported file type.")
extracted_text = None
# Check if text extraction was successful
if extracted_text is None:
st.error("Failed to extract text from the document.")
elif len(extracted_text.strip()) == 0:
st.error("No text could be extracted from the document. The document may be empty or contain only images.")
else:
# Display extracted text preview
with st.expander("Extracted Text Preview", expanded=False):
st.text(extracted_text[:500] + "..." if len(extracted_text) > 500 else extracted_text)
# Define candidate labels for zero-shot classification
# These are example labels - adjust based on your use case
candidate_labels = [
"invoice",
"contract",
"report",
"letter",
"receipt",
"form",
"memo",
"other"
]
# Load model and classify
start_time = time.time()
if model_option == "distilbert-base-uncased":
classifier = load_distilbert_model()
if classifier:
predicted_label, confidence = classify_with_distilbert(
classifier, extracted_text, candidate_labels
)
model_name = "DistilBERT Base Uncased"
else:
predicted_label, confidence = None, None
model_name = "DistilBERT Base Uncased"
elif model_option == "google/tinybert-6L-384D":
classifier = load_tinybert_model()
if classifier:
predicted_label, confidence = classify_with_tinybert(
classifier, extracted_text, candidate_labels
)
model_name = "TinyBERT 6L-384D"
else:
predicted_label, confidence = None, None
model_name = "TinyBERT 6L-384D"
# FastText option commented out
# elif model_option in ["FastText", "FastText (Not Available)"]:
# if not FASTTEXT_AVAILABLE:
# st.error("FastText is not available. Please install it or select a different model.")
# predicted_label, confidence = None, None
# model_name = "FastText"
# else:
# model = load_fasttext_model()
# if model:
# predicted_label, confidence = classify_with_fasttext(model, extracted_text)
# model_name = "FastText"
# else:
# predicted_label, confidence = None, None
# model_name = "FastText"
inference_time = time.time() - start_time
# Display results
if predicted_label is not None:
st.success("Classification Complete!")
st.markdown("---")
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Predicted Label", predicted_label)
with col2:
st.metric("Confidence", f"{confidence:.2%}" if confidence else "N/A")
with col3:
st.metric("Inference Time", f"{inference_time:.4f}s")
st.info(f"**Model Used:** {model_name}")
else:
st.error("Classification failed. Please check the model loading status above.")
# Footer
st.markdown("---")
st.caption("Document Classification Performance Testing Tool")