# Streamlit document-classification demo (Hugging Face Spaces app).
import streamlit as st
import fitz  # PyMuPDF: PDF parsing and text extraction
import pytesseract  # Tesseract OCR bindings for image text extraction
from PIL import Image
import io
import time
import os
from transformers import pipeline, AutoTokenizer

# fasttext is optional (it lacks universal prebuilt wheels); the app
# degrades gracefully when it is missing by checking FASTTEXT_AVAILABLE.
try:
    import fasttext
    FASTTEXT_AVAILABLE = True
except ImportError:
    FASTTEXT_AVAILABLE = False
    fasttext = None  # sentinel so later references don't raise NameError
# --- Tesseract binary discovery --------------------------------------------
# pytesseract shells out to the `tesseract` executable, so we must point it
# at a real binary. Windows installs are not on PATH by default; on Linux
# (e.g. Hugging Face Spaces with packages.txt) it usually is.
import shutil  # hoisted: previously re-imported on every loop iteration

if os.name == 'nt':  # Windows
    tesseract_path = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
    if os.path.exists(tesseract_path):
        pytesseract.pytesseract.tesseract_cmd = tesseract_path
else:
    # Linux/Unix: prefer whatever is on PATH (shutil.which is loop-invariant,
    # so call it once), then fall back to common install locations, and
    # finally the bare command name as a last resort.
    tesseract_cmd = shutil.which('tesseract')
    if tesseract_cmd is None:
        for candidate in ('/usr/bin/tesseract', '/usr/local/bin/tesseract'):
            if os.path.exists(candidate):
                tesseract_cmd = candidate
                break
        else:
            tesseract_cmd = 'tesseract'  # hope it resolves at call time
    pytesseract.pytesseract.tesseract_cmd = tesseract_cmd
def get_hf_token():
    """
    Resolve a Hugging Face access token.

    Looks in Streamlit secrets first (the mechanism used on HF Spaces
    deployments), then falls back to environment variables for local
    development. Returns None when no token is configured anywhere.
    """
    try:
        secrets = getattr(st, 'secrets', None)
        if secrets is not None:
            # Flat layout: HF_TOKEN = "..."
            if 'HF_TOKEN' in secrets:
                return secrets['HF_TOKEN']
            # Nested layout: [HF] TOKEN = "..."
            if 'HF' in secrets and 'TOKEN' in secrets['HF']:
                return secrets['HF']['TOKEN']
    except Exception:
        # st.secrets can raise when no secrets file exists; ignore and
        # fall through to the environment.
        pass
    return os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
# --- Streamlit page setup ---------------------------------------------------
st.set_page_config(
    page_title="Document Classification Performance Testing",
    page_icon="📄",
    layout="centered"
)

# Cache of loaded models/tokenizers keyed by name ('distilbert', ...).
# Kept in session state so models survive Streamlit's script reruns.
if 'models_loaded' not in st.session_state:
    st.session_state.models_loaded = {}
def extract_text_from_pdf(pdf_file):
    """Extract plain text from an uploaded PDF using PyMuPDF.

    Args:
        pdf_file: a file-like object (Streamlit UploadedFile); its read
            position is reset afterwards so the caller may re-read it.

    Returns:
        The concatenated text of all pages (stripped), or None on failure
        (the error is surfaced in the UI via st.error).
    """
    try:
        pdf_bytes = pdf_file.read()
        pdf_file.seek(0)  # reset pointer so the upload can be reused
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
        try:
            # try/finally so the document handle is released even if text
            # extraction raises mid-way (the old code leaked it on error).
            text = "".join(page.get_text() for page in doc)
        finally:
            doc.close()
        return text.strip()
    except Exception as e:
        st.error(f"Error extracting text from PDF: {str(e)}")
        return None
def extract_text_from_image(image_file):
    """Run Tesseract OCR over an uploaded image and return its text.

    Returns the stripped OCR output, or None on failure (the error is
    shown via st.error). The upload's read position is reset so the
    caller can re-read the file.
    """
    try:
        raw = image_file.read()
        image_file.seek(0)  # leave the upload re-readable for the caller
        pil_image = Image.open(io.BytesIO(raw))
        ocr_text = pytesseract.image_to_string(pil_image)
        return ocr_text.strip()
    except Exception as e:
        st.error(f"Error extracting text from image: {str(e)}")
        return None
def load_distilbert_model():
    """Return a cached DistilBERT zero-shot classification pipeline.

    The pipeline (and its tokenizer, stored under 'distilbert_tokenizer')
    is created once and memoized in st.session_state.models_loaded.
    Returns None if loading fails (an error is shown via st.error).
    """
    cache = st.session_state.models_loaded
    if 'distilbert' not in cache:
        with st.spinner("Loading DistilBERT model (first time may take a while)..."):
            try:
                model_name = "distilbert-base-uncased"
                cache['distilbert'] = pipeline(
                    "zero-shot-classification",
                    model=model_name,
                    truncation=True,  # clip over-length inputs automatically
                    max_length=512,   # model's maximum sequence length
                )
                # Keep the tokenizer too, for precise truncation elsewhere.
                cache['distilbert_tokenizer'] = AutoTokenizer.from_pretrained(model_name)
            except Exception as e:
                st.error(f"Error loading DistilBERT model: {str(e)}")
                return None
    return cache['distilbert']
def load_tinybert_model():
    """Return a cached TinyBERT zero-shot classification pipeline.

    Tries 'huawei-noah/TinyBERT_General_6L_768D' first (with the HF token
    if one is configured, then anonymously), falling back to
    'google/tinybert-6L-384D'. The pipeline and its tokenizer are memoized
    in st.session_state.models_loaded under 'tinybert' /
    'tinybert_tokenizer'. Returns None if every attempt fails.
    """
    cache = st.session_state.models_loaded
    if 'tinybert' not in cache:
        with st.spinner("Loading TinyBERT model (first time may take a while)..."):
            try:
                model_name = "huawei-noah/TinyBERT_General_6L_768D"
                # Token works for both local and HF Spaces deployments.
                hf_token = get_hf_token()
                try:
                    classifier = pipeline(
                        "zero-shot-classification",
                        model=model_name,
                        token=hf_token if hf_token else None,
                        truncation=True,
                        max_length=512,
                    )
                    tokenizer = AutoTokenizer.from_pretrained(
                        model_name,
                        token=hf_token if hf_token else None,
                    )
                except Exception:
                    # Retry anonymously (sufficient for public models).
                    # BUGFIX: the old code unconditionally re-loaded the
                    # tokenizer WITH the token after this fallback, which
                    # both discarded the fallback tokenizer and re-raised
                    # the original token error, defeating the retry.
                    classifier = pipeline(
                        "zero-shot-classification",
                        model=model_name,
                        truncation=True,
                        max_length=512,
                    )
                    tokenizer = AutoTokenizer.from_pretrained(model_name)
                cache['tinybert'] = classifier
                cache['tinybert_tokenizer'] = tokenizer
            except Exception as e:
                error_msg = str(e)
                # Fall back to the alternative model identifier.
                if "huawei-noah" in error_msg.lower() or "not a valid model identifier" in error_msg.lower():
                    try:
                        st.info("Trying alternative TinyBERT model identifier...")
                        model_name = "google/tinybert-6L-384D"
                        hf_token = get_hf_token()
                        classifier = pipeline(
                            "zero-shot-classification",
                            model=model_name,
                            token=hf_token if hf_token else None,
                            truncation=True,
                            max_length=512,
                        )
                        tokenizer = AutoTokenizer.from_pretrained(
                            model_name,
                            token=hf_token if hf_token else None,
                        )
                        cache['tinybert'] = classifier
                        cache['tinybert_tokenizer'] = tokenizer
                    except Exception as e2:
                        st.error(
                            f"Error loading TinyBERT model. Tried both 'huawei-noah/TinyBERT_General_6L_768D' "
                            f"and 'google/tinybert-6L-384D'. "
                            f"If these models are private, ensure your Hugging Face token is set in the "
                            f"HF_TOKEN or HUGGINGFACEHUB_API_TOKEN environment variable. "
                            f"Details: {str(e2)}"
                        )
                        return None
                else:
                    st.error(
                        f"Error loading TinyBERT model. "
                        f"If this model is private, ensure your Hugging Face token is set in the "
                        f"HF_TOKEN or HUGGINGFACEHUB_API_TOKEN environment variable. "
                        f"Details: {error_msg}"
                    )
                    return None
    return cache['tinybert']
def load_fasttext_model():
    """Return a cached FastText model loaded from 'lid.176.bin' on disk.

    Uses Facebook's pre-trained language-identification model as a
    stand-in; a production document classifier would use a custom-trained
    FastText model. Returns None when fasttext is not installed, the model
    file is missing, or loading fails (with an explanatory UI message).
    """
    if not FASTTEXT_AVAILABLE:
        st.error("FastText is not installed. Please install it using: pip install fasttext")
        return None
    cache = st.session_state.models_loaded
    if 'fasttext' not in cache:
        with st.spinner("Loading FastText model (first time may take a while)..."):
            # Download: https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin
            model_path = "lid.176.bin"
            if not os.path.exists(model_path):
                st.warning("FastText model file not found. Please download 'lid.176.bin' from "
                           "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin "
                           "and place it in the current directory.")
                return None
            try:
                cache['fasttext'] = fasttext.load_model(model_path)
            except Exception as e:
                st.error(f"Error loading FastText model: {str(e)}")
                return None
    return cache['fasttext']
def truncate_text_for_model(text, tokenizer=None, max_length=500):
    """Truncate text to at most max_length tokens for a transformer model.

    Args:
        text: the input string.
        tokenizer: optional HF tokenizer for exact token counting; when
            None, a ~4-characters-per-token approximation is used.
        max_length: token budget (default 500, leaving headroom under the
            usual 512-token model limit for special tokens).

    Returns:
        Tuple (possibly_truncated_text, was_truncated).
    """
    if tokenizer is not None:
        # Count the full token stream first so that a text of exactly
        # max_length tokens is NOT falsely reported as truncated (the old
        # code pre-truncated inside encode() and then compared with >=,
        # so it could never see the true length).
        tokens = tokenizer.encode(text, add_special_tokens=False)
        if len(tokens) > max_length:
            truncated_text = tokenizer.decode(tokens[:max_length], skip_special_tokens=True)
            return truncated_text, True
        return text, False
    # Fallback: character-based approximation (~4 chars per token), with a
    # small buffer reserved for special tokens.
    max_chars = (max_length - 12) * 4
    if len(text) <= max_chars:
        return text, False
    # Cut at a word boundary so we don't end mid-token.
    truncated = text[:max_chars].rsplit(' ', 1)[0]
    return truncated + "...", True
def classify_with_distilbert(classifier, text, candidate_labels):
    """Classify text with the DistilBERT zero-shot pipeline.

    Returns (top_label, top_score); (None, None) when the classifier is
    unavailable or inference fails (the error is shown via st.error).
    """
    if classifier is None:
        return None, None
    # Pre-truncate with the model's own tokenizer when we have it cached
    # (DistilBERT's hard limit is 512 tokens).
    tok = st.session_state.models_loaded.get('distilbert_tokenizer')
    clipped, was_clipped = truncate_text_for_model(text, tokenizer=tok, max_length=500)
    if was_clipped:
        st.warning("⚠️ Text was truncated to fit model's maximum input length (512 tokens).")
    try:
        scored = classifier(clipped, candidate_labels, truncation=True)
        best = (scored['labels'][0], scored['scores'][0])
    except Exception as e:
        st.error(f"Classification error: {str(e)}")
        return None, None
    return best
def classify_with_tinybert(classifier, text, candidate_labels):
    """Run zero-shot classification with TinyBERT.

    Returns the best (label, score) pair, or (None, None) if there is no
    classifier or inference raises (the error is reported via st.error).
    """
    if classifier is None:
        return None, None
    # TinyBERT shares BERT's 512-token ceiling; trim the input first using
    # the cached tokenizer when one is available.
    tokenizer = st.session_state.models_loaded.get('tinybert_tokenizer')
    trimmed_text, trimmed = truncate_text_for_model(
        text, tokenizer=tokenizer, max_length=500
    )
    if trimmed:
        st.warning("⚠️ Text was truncated to fit model's maximum input length (512 tokens).")
    try:
        outcome = classifier(trimmed_text, candidate_labels, truncation=True)
        return outcome['labels'][0], outcome['scores'][0]
    except Exception as e:
        st.error(f"Classification error: {str(e)}")
        return None, None
def classify_with_fasttext(model, text):
    """Classify text with a FastText model and return (label, score).

    Returns (None, None) when no model is available. NOTE: the bundled
    lid.176.bin model does language identification; a custom-trained
    FastText model would be needed for real document classification.
    """
    if model is None:
        return None, None
    # BUGFIX: fastText's predict() raises ValueError on embedded newlines,
    # and extracted document text almost always contains them — collapse
    # all whitespace runs to single spaces first.
    flattened = " ".join(text.split())
    predictions = model.predict(flattened, k=1)
    label = predictions[0][0].replace('__label__', '')
    score = float(predictions[1][0])
    return label, score
# --- Main UI ----------------------------------------------------------------
st.title("📄 Document Classification Performance Testing")
st.markdown("---")

# File upload (the type filter restricts what the browser will accept).
uploaded_file = st.file_uploader(
    "Upload a document",
    type=['pdf', 'png', 'jpg', 'jpeg'],
    help="Upload a PDF or Image file (PNG/JPG/JPEG)"
)

# Model selection; these strings are matched exactly in the button handler.
model_options = [
    "distilbert-base-uncased",
    "google/tinybert-6L-384D"
]
# FastText option commented out (kept for reference; fasttext wheels are
# often unavailable, so the feature is disabled in the UI).
# if FASTTEXT_AVAILABLE:
#     model_options.append("FastText")
# else:
#     model_options.append("FastText (Not Available)")
model_option = st.selectbox(
    "Select Model",
    options=model_options,
    help="Choose a model for document classification"
)
# FastText warning commented out
# if model_option == "FastText (Not Available)":
#     st.warning("⚠️ FastText is not installed. To use FastText, you can try installing a pre-built wheel: "
#                "`pip install fasttext-wheel` or use conda: `conda install -c conda-forge fasttext`. "
#                "Alternatively, use DistilBERT or TinyBERT models.")
# Classification button: the whole extract -> classify -> report flow runs
# inside this handler (Streamlit reruns the script top-down on each click).
if st.button("Classify Document", type="primary"):
    if uploaded_file is None:
        st.warning("Please upload a file first.")
    else:
        # Extract text, dispatching on the uploaded file's extension.
        file_extension = uploaded_file.name.split('.')[-1].lower()
        if file_extension == 'pdf':
            extracted_text = extract_text_from_pdf(uploaded_file)
        elif file_extension in ['png', 'jpg', 'jpeg']:
            extracted_text = extract_text_from_image(uploaded_file)
        else:
            # Should be unreachable: the uploader already filters types.
            st.error("Unsupported file type.")
            extracted_text = None
        # Check if text extraction was successful.
        if extracted_text is None:
            st.error("Failed to extract text from the document.")
        elif len(extracted_text.strip()) == 0:
            st.error("No text could be extracted from the document. The document may be empty or contain only images.")
        else:
            # Display extracted text preview (first 500 characters).
            with st.expander("Extracted Text Preview", expanded=False):
                st.text(extracted_text[:500] + "..." if len(extracted_text) > 500 else extracted_text)
            # Candidate labels for zero-shot classification.
            # These are example labels - adjust based on your use case.
            candidate_labels = [
                "invoice",
                "contract",
                "report",
                "letter",
                "receipt",
                "form",
                "memo",
                "other"
            ]
            # Load model and classify.
            # NOTE(review): the timer starts before model loading, so on a
            # cold run "Inference Time" includes model download/load time.
            start_time = time.time()
            if model_option == "distilbert-base-uncased":
                classifier = load_distilbert_model()
                if classifier:
                    predicted_label, confidence = classify_with_distilbert(
                        classifier, extracted_text, candidate_labels
                    )
                    model_name = "DistilBERT Base Uncased"
                else:
                    # Loading failed; an error was already shown by the loader.
                    predicted_label, confidence = None, None
                    model_name = "DistilBERT Base Uncased"
            elif model_option == "google/tinybert-6L-384D":
                classifier = load_tinybert_model()
                if classifier:
                    predicted_label, confidence = classify_with_tinybert(
                        classifier, extracted_text, candidate_labels
                    )
                    model_name = "TinyBERT 6L-384D"
                else:
                    predicted_label, confidence = None, None
                    model_name = "TinyBERT 6L-384D"
            # FastText option commented out (disabled along with the UI entry).
            # elif model_option in ["FastText", "FastText (Not Available)"]:
            #     if not FASTTEXT_AVAILABLE:
            #         st.error("FastText is not available. Please install it or select a different model.")
            #         predicted_label, confidence = None, None
            #         model_name = "FastText"
            #     else:
            #         model = load_fasttext_model()
            #         if model:
            #             predicted_label, confidence = classify_with_fasttext(model, extracted_text)
            #             model_name = "FastText"
            #         else:
            #             predicted_label, confidence = None, None
            #             model_name = "FastText"
            inference_time = time.time() - start_time
            # Display results: three metric columns plus the model name.
            if predicted_label is not None:
                st.success("Classification Complete!")
                st.markdown("---")
                col1, col2, col3 = st.columns(3)
                with col1:
                    st.metric("Predicted Label", predicted_label)
                with col2:
                    st.metric("Confidence", f"{confidence:.2%}" if confidence else "N/A")
                with col3:
                    st.metric("Inference Time", f"{inference_time:.4f}s")
                st.info(f"**Model Used:** {model_name}")
            else:
                st.error("Classification failed. Please check the model loading status above.")
# Footer (rendered on every run, outside the button handler).
st.markdown("---")
st.caption("Document Classification Performance Testing Tool")