"""Streamlit app for benchmarking zero-shot document classification models.

Extracts text from an uploaded PDF (via PyMuPDF) or image (via Tesseract
OCR), classifies it with a user-selected model (DistilBERT / TinyBERT
zero-shot pipelines; FastText support present but disabled in the UI),
and reports the predicted label, confidence, and inference time.
"""

import io
import os
import shutil
import time

import streamlit as st
import fitz  # PyMuPDF
import pytesseract
from PIL import Image
from transformers import pipeline, AutoTokenizer

# Try to import fasttext; degrade gracefully when the wheel is not installed
# (the UI falls back to transformer models only).
try:
    import fasttext
    FASTTEXT_AVAILABLE = True
except ImportError:
    FASTTEXT_AVAILABLE = False
    fasttext = None

# Configure the Tesseract binary path.
# Windows: use the default install location if it exists.
# Linux/Unix (e.g. Hugging Face Spaces): Tesseract is usually on PATH after
# installation via packages.txt, but probe common locations as a fallback.
if os.name == 'nt':  # Windows
    tesseract_path = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
    if os.path.exists(tesseract_path):
        pytesseract.pytesseract.tesseract_cmd = tesseract_path
else:
    common_paths = [
        '/usr/bin/tesseract',
        '/usr/local/bin/tesseract',
        'tesseract'  # Try system PATH
    ]
    for path in common_paths:
        try:
            # Prefer whatever is on PATH; fall back to the candidate path.
            tesseract_cmd = shutil.which('tesseract') or path
            if tesseract_cmd and (os.path.exists(tesseract_cmd) or tesseract_cmd == 'tesseract'):
                pytesseract.pytesseract.tesseract_cmd = tesseract_cmd
                break
        except Exception:
            continue


def get_hf_token():
    """Return a Hugging Face API token, or None if none is configured.

    Priority: Streamlit secrets (Hugging Face Spaces deployment) >
    environment variables (local development). Checks both a flat
    ``HF_TOKEN`` secret and a nested ``HF.TOKEN`` structure.
    """
    try:
        if hasattr(st, 'secrets') and 'HF_TOKEN' in st.secrets:
            return st.secrets['HF_TOKEN']
        # Also check nested structure (HF.TOKEN)
        if hasattr(st, 'secrets') and 'HF' in st.secrets and 'TOKEN' in st.secrets['HF']:
            return st.secrets['HF']['TOKEN']
    except Exception:
        # st.secrets raises when no secrets file exists; fall through to env.
        pass
    # Fallback to environment variables (for local development)
    return os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")


# Configure page
st.set_page_config(
    page_title="Document Classification Performance Testing",
    page_icon="📄",
    layout="centered"
)

# Initialize session state: cache loaded models across Streamlit reruns.
if 'models_loaded' not in st.session_state:
    st.session_state.models_loaded = {}


def extract_text_from_pdf(pdf_file):
    """Extract text from an uploaded PDF using PyMuPDF.

    Returns the stripped text, or None (with a Streamlit error shown)
    on failure. The file pointer is reset so the upload can be reused.
    """
    try:
        pdf_bytes = pdf_file.read()
        pdf_file.seek(0)  # Reset file pointer
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
        text = ""
        for page in doc:
            text += page.get_text()
        doc.close()
        return text.strip()
    except Exception as e:
        st.error(f"Error extracting text from PDF: {str(e)}")
        return None


def extract_text_from_image(image_file):
    """Extract text from an uploaded image using Tesseract OCR.

    Returns the stripped text, or None (with a Streamlit error shown)
    on failure. The file pointer is reset so the upload can be reused.
    """
    try:
        image_bytes = image_file.read()
        image_file.seek(0)  # Reset file pointer
        image = Image.open(io.BytesIO(image_bytes))
        text = pytesseract.image_to_string(image)
        return text.strip()
    except Exception as e:
        st.error(f"Error extracting text from image: {str(e)}")
        return None


def load_distilbert_model():
    """Load and cache a DistilBERT zero-shot classification pipeline.

    Returns the cached pipeline, or None if loading failed.
    The matching tokenizer is stored alongside for accurate truncation.
    """
    if 'distilbert' not in st.session_state.models_loaded:
        with st.spinner("Loading DistilBERT model (first time may take a while)..."):
            try:
                model_name = "distilbert-base-uncased"
                classifier = pipeline(
                    "zero-shot-classification",
                    model=model_name,
                    truncation=True,  # Enable automatic truncation
                    max_length=512    # Set max length
                )
                # Also store tokenizer for accurate truncation if needed
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                st.session_state.models_loaded['distilbert'] = classifier
                st.session_state.models_loaded['distilbert_tokenizer'] = tokenizer
            except Exception as e:
                st.error(f"Error loading DistilBERT model: {str(e)}")
                return None
    return st.session_state.models_loaded['distilbert']


def load_tinybert_model():
    """Load and cache a TinyBERT zero-shot classification pipeline.

    Tries 'huawei-noah/TinyBERT_General_6L_768D' first (with the HF token
    if one is configured, then without), falling back to
    'google/tinybert-6L-384D' when the identifier is rejected.
    Returns the cached pipeline, or None if all attempts failed.
    """
    if 'tinybert' not in st.session_state.models_loaded:
        with st.spinner("Loading TinyBERT model (first time may take a while)..."):
            try:
                # google/tinybert-6L-384D may not exist; prefer the
                # huawei-noah identifier.
                model_name = "huawei-noah/TinyBERT_General_6L_768D"
                # Get Hugging Face token (works for both local and HF Spaces)
                hf_token = get_hf_token()
                try:
                    # Try with token first (needed for private/gated models).
                    classifier = pipeline(
                        "zero-shot-classification",
                        model=model_name,
                        token=hf_token,
                        truncation=True,
                        max_length=512
                    )
                    tokenizer = AutoTokenizer.from_pretrained(
                        model_name, token=hf_token
                    )
                except Exception:
                    # If that fails, retry without a token (public models).
                    classifier = pipeline(
                        "zero-shot-classification",
                        model=model_name,
                        truncation=True,
                        max_length=512
                    )
                    tokenizer = AutoTokenizer.from_pretrained(model_name)
                # NOTE: the original code re-loaded the tokenizer here with
                # the token unconditionally, which could fail after the
                # token-less fallback had already succeeded; both branches
                # above now leave `tokenizer` correctly set.
                st.session_state.models_loaded['tinybert'] = classifier
                st.session_state.models_loaded['tinybert_tokenizer'] = tokenizer
            except Exception as e:
                error_msg = str(e)
                # Try the original model name as fallback
                if "huawei-noah" in error_msg.lower() or "not a valid model identifier" in error_msg.lower():
                    try:
                        st.info("Trying alternative TinyBERT model identifier...")
                        model_name = "google/tinybert-6L-384D"
                        hf_token = get_hf_token()
                        classifier = pipeline(
                            "zero-shot-classification",
                            model=model_name,
                            token=hf_token,
                            truncation=True,
                            max_length=512
                        )
                        tokenizer = AutoTokenizer.from_pretrained(
                            model_name, token=hf_token
                        )
                        st.session_state.models_loaded['tinybert'] = classifier
                        st.session_state.models_loaded['tinybert_tokenizer'] = tokenizer
                    except Exception as e2:
                        st.error(
                            f"Error loading TinyBERT model. Tried both 'huawei-noah/TinyBERT_General_6L_768D' "
                            f"and 'google/tinybert-6L-384D'. "
                            f"If these models are private, ensure your Hugging Face token is set in the "
                            f"HF_TOKEN or HUGGINGFACEHUB_API_TOKEN environment variable. "
                            f"Details: {str(e2)}"
                        )
                        return None
                else:
                    st.error(
                        f"Error loading TinyBERT model. "
                        f"If this model is private, ensure your Hugging Face token is set in the "
                        f"HF_TOKEN or HUGGINGFACEHUB_API_TOKEN environment variable. "
                        f"Details: {error_msg}"
                    )
                    return None
    return st.session_state.models_loaded['tinybert']


def load_fasttext_model():
    """Load and cache a FastText model from ./lid.176.bin.

    Uses the pre-trained language-identification model as an example;
    production document classification would use a custom-trained model.
    Returns the cached model, or None if fasttext is unavailable, the
    model file is missing, or loading fails.
    """
    if not FASTTEXT_AVAILABLE:
        st.error("FastText is not installed. Please install it using: pip install fasttext")
        return None
    if 'fasttext' not in st.session_state.models_loaded:
        with st.spinner("Loading FastText model (first time may take a while)..."):
            try:
                # Download from:
                # https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin
                model_path = "lid.176.bin"
                if not os.path.exists(model_path):
                    st.warning("FastText model file not found. Please download 'lid.176.bin' from "
                               "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin "
                               "and place it in the current directory.")
                    return None
                model = fasttext.load_model(model_path)
                st.session_state.models_loaded['fasttext'] = model
            except Exception as e:
                st.error(f"Error loading FastText model: {str(e)}")
                return None
    return st.session_state.models_loaded['fasttext']


def truncate_text_for_model(text, tokenizer=None, max_length=500):
    """Truncate text to fit a model's maximum sequence length.

    With a tokenizer, uses accurate token counting; otherwise falls back
    to a character-based approximation (~4 chars/token, minus a buffer
    for special tokens, cut at a word boundary).

    Returns:
        (text, was_truncated) tuple.
    """
    if tokenizer is not None:
        # Use tokenizer for accurate truncation
        tokens = tokenizer.encode(text, add_special_tokens=False,
                                  max_length=max_length, truncation=True)
        if len(tokens) >= max_length:
            # Decode back to text (this will be properly truncated)
            truncated_text = tokenizer.decode(tokens, skip_special_tokens=True)
            return truncated_text, True
        return text, False
    # Fallback: character-based approximation (roughly 4 chars per token),
    # leaving a buffer for special tokens.
    max_chars = (max_length - 12) * 4
    if len(text) <= max_chars:
        return text, False
    # Truncate at word boundary
    truncated = text[:max_chars].rsplit(' ', 1)[0]
    return truncated + "...", True


def classify_with_distilbert(classifier, text, candidate_labels):
    """Classify text with the DistilBERT pipeline.

    Returns (top_label, top_score), or (None, None) when the classifier
    is missing or inference fails. Warns in the UI if input was truncated.
    """
    if classifier is None:
        return None, None
    # Get tokenizer if available for accurate truncation
    tokenizer = st.session_state.models_loaded.get('distilbert_tokenizer')
    # Truncate text if needed (DistilBERT max length is 512)
    truncated_text, was_truncated = truncate_text_for_model(
        text, tokenizer=tokenizer, max_length=500)
    if was_truncated:
        st.warning("⚠️ Text was truncated to fit model's maximum input length (512 tokens).")
    try:
        result = classifier(truncated_text, candidate_labels, truncation=True)
        return result['labels'][0], result['scores'][0]
    except Exception as e:
        st.error(f"Classification error: {str(e)}")
        return None, None


def classify_with_tinybert(classifier, text, candidate_labels):
    """Classify text with the TinyBERT pipeline.

    Returns (top_label, top_score), or (None, None) when the classifier
    is missing or inference fails. Warns in the UI if input was truncated.
    """
    if classifier is None:
        return None, None
    # Get tokenizer if available for accurate truncation
    tokenizer = st.session_state.models_loaded.get('tinybert_tokenizer')
    # Truncate text if needed (TinyBERT max length is 512)
    truncated_text, was_truncated = truncate_text_for_model(
        text, tokenizer=tokenizer, max_length=500)
    if was_truncated:
        st.warning("⚠️ Text was truncated to fit model's maximum input length (512 tokens).")
    try:
        result = classifier(truncated_text, candidate_labels, truncation=True)
        return result['labels'][0], result['scores'][0]
    except Exception as e:
        st.error(f"Classification error: {str(e)}")
        return None, None


def classify_with_fasttext(model, text):
    """Classify text with FastText (language identification as an example).

    Returns (label, score), or (None, None) when no model is loaded.
    In production, use a custom-trained model for document classification.
    """
    if model is None:
        return None, None
    predictions = model.predict(text, k=1)
    label = predictions[0][0].replace('__label__', '')
    score = float(predictions[1][0])
    return label, score


# ---------------------------------------------------------------- Main UI
st.title("📄 Document Classification Performance Testing")
st.markdown("---")

# File upload
uploaded_file = st.file_uploader(
    "Upload a document",
    type=['pdf', 'png', 'jpg', 'jpeg'],
    help="Upload a PDF or Image file (PNG/JPG/JPEG)"
)

# Model selection
model_options = [
    "distilbert-base-uncased",
    "google/tinybert-6L-384D"
]
# FastText option commented out
# if FASTTEXT_AVAILABLE:
#     model_options.append("FastText")
# else:
#     model_options.append("FastText (Not Available)")

model_option = st.selectbox(
    "Select Model",
    options=model_options,
    help="Choose a model for document classification"
)

# FastText warning commented out
# if model_option == "FastText (Not Available)":
#     st.warning("⚠️ FastText is not installed. To use FastText, you can try installing a pre-built wheel: "
#                "`pip install fasttext-wheel` or use conda: `conda install -c conda-forge fasttext`. "
#                "Alternatively, use DistilBERT or TinyBERT models.")

# Classification button
if st.button("Classify Document", type="primary"):
    if uploaded_file is None:
        st.warning("Please upload a file first.")
    else:
        # Extract text based on file type
        file_extension = uploaded_file.name.split('.')[-1].lower()
        if file_extension == 'pdf':
            extracted_text = extract_text_from_pdf(uploaded_file)
        elif file_extension in ['png', 'jpg', 'jpeg']:
            extracted_text = extract_text_from_image(uploaded_file)
        else:
            st.error("Unsupported file type.")
            extracted_text = None

        # Check if text extraction was successful
        if extracted_text is None:
            st.error("Failed to extract text from the document.")
        elif len(extracted_text.strip()) == 0:
            st.error("No text could be extracted from the document. The document may be empty or contain only images.")
        else:
            # Display extracted text preview
            with st.expander("Extracted Text Preview", expanded=False):
                st.text(extracted_text[:500] + "..." if len(extracted_text) > 500 else extracted_text)

            # Candidate labels for zero-shot classification — examples;
            # adjust for your use case.
            candidate_labels = [
                "invoice", "contract", "report", "letter",
                "receipt", "form", "memo", "other"
            ]

            # Load model and classify (timing includes model load on first use)
            start_time = time.time()
            if model_option == "distilbert-base-uncased":
                classifier = load_distilbert_model()
                model_name = "DistilBERT Base Uncased"
                if classifier:
                    predicted_label, confidence = classify_with_distilbert(
                        classifier, extracted_text, candidate_labels
                    )
                else:
                    predicted_label, confidence = None, None
            elif model_option == "google/tinybert-6L-384D":
                classifier = load_tinybert_model()
                model_name = "TinyBERT 6L-384D"
                if classifier:
                    predicted_label, confidence = classify_with_tinybert(
                        classifier, extracted_text, candidate_labels
                    )
                else:
                    predicted_label, confidence = None, None
            # FastText option commented out
            # elif model_option in ["FastText", "FastText (Not Available)"]:
            #     if not FASTTEXT_AVAILABLE:
            #         st.error("FastText is not available. Please install it or select a different model.")
            #         predicted_label, confidence = None, None
            #         model_name = "FastText"
            #     else:
            #         model = load_fasttext_model()
            #         if model:
            #             predicted_label, confidence = classify_with_fasttext(model, extracted_text)
            #             model_name = "FastText"
            #         else:
            #             predicted_label, confidence = None, None
            #             model_name = "FastText"
            inference_time = time.time() - start_time

            # Display results
            if predicted_label is not None:
                st.success("Classification Complete!")
                st.markdown("---")
                col1, col2, col3 = st.columns(3)
                with col1:
                    st.metric("Predicted Label", predicted_label)
                with col2:
                    st.metric("Confidence", f"{confidence:.2%}" if confidence else "N/A")
                with col3:
                    st.metric("Inference Time", f"{inference_time:.4f}s")
                st.info(f"**Model Used:** {model_name}")
            else:
                st.error("Classification failed. Please check the model loading status above.")

# Footer
st.markdown("---")
st.caption("Document Classification Performance Testing Tool")