Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import fitz # PyMuPDF | |
| import torch | |
| import time | |
| import re | |
| import os | |
| import torch.nn.functional as F | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForSequenceClassification | |
| ) | |
def patch_numpy_copy_false(np_module):
    """Make ``np.array(obj, copy=False)`` behave like NumPy 1.x on NumPy 2.x.

    FastText's prediction code calls ``np.array(probs, copy=False)``, which
    NumPy 2.x rejects whenever a copy would be needed. This replaces
    ``np_module.array`` with a wrapper that routes ``copy=False`` through
    ``np.asarray`` (copy only if needed) and forwards every other call to the
    original function unchanged.

    The patch is global to the numpy module and idempotent. Streamlit
    re-executes this script on every interaction while ``import numpy`` keeps
    returning the same cached module object. Without the marker check, each
    rerun would wrap the previous wrapper and grow the call chain without bound.

    Does nothing on NumPy < 2 or when the version string cannot be parsed.

    Args:
        np_module: the imported ``numpy`` module, patched in place.
    """
    major = str(getattr(np_module, "__version__", "0")).split(".")[0]
    if not major.isdigit() or int(major) < 2:
        return
    if getattr(np_module.array, "_copy_false_shim", False):
        return  # already installed by an earlier script run

    original_array = np_module.array

    def _patched_np_array(obj, dtype=None, copy=True, order='K', subok=False, ndmin=0, **kwargs):
        if copy is False:
            # asarray copies only when needed; it takes no subok argument.
            result = np_module.asarray(obj, dtype=dtype, order=order)
            while result.ndim < ndmin:
                result = np_module.expand_dims(result, 0)
            return result
        return original_array(obj, dtype=dtype, copy=copy, order=order, subok=subok, ndmin=ndmin, **kwargs)

    _patched_np_array._copy_false_shim = True
    np_module.array = _patched_np_array


try:
    import numpy as np
    # Install the NumPy 2.x shim before importing fasttext.
    patch_numpy_copy_false(np)
    import fasttext
    FASTTEXT_AVAILABLE = True
except ImportError:
    # numpy or fasttext missing: FastText is not offered in the model list.
    FASTTEXT_AVAILABLE = False
    np = None
# -------------------------------
# Page Configuration
# -------------------------------
# Browser-tab title and icon, wide layout, sidebar collapsed on load.
# Streamlit expects set_page_config before other st.* output calls, so it is
# the first Streamlit call in the script.
_PAGE_CONFIG = dict(
    layout="wide",
    page_title="PDF Document Classification",
    initial_sidebar_state="collapsed",
    page_icon="📄",
)
st.set_page_config(**_PAGE_CONFIG)
# -------------------------------
# Beautiful Blue Theme CSS
# -------------------------------
# Global stylesheet injected on every script run. The textarea colour
# overrides live in a single group of rules below. Only real selectors are
# used, because a rule with an invalid selector (e.g. a non-existent
# pseudo-element) is discarded entirely by the browser.
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;500;600;700;800&display=swap');
* {
font-family: 'Poppins', sans-serif;
}
.stApp {
background: linear-gradient(135deg, #667eea 0%, #764ba2 25%, #667eea 50%, #4facfe 75%, #00f2fe 100%);
background-size: 400% 400%;
animation: gradientShift 15s ease infinite;
}
@keyframes gradientShift {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
.main .block-container {
max-width: 1100px;
padding-top: 2rem;
padding-bottom: 3rem;
}
/* Header Styling */
h1 {
font-size: 3rem !important;
font-weight: 800 !important;
color: #ffffff !important;
text-align: center;
margin-bottom: 0.5rem !important;
text-shadow: 2px 2px 8px rgba(0,0,0,0.2);
letter-spacing: -0.02em;
}
h2 {
font-size: 1.75rem !important;
font-weight: 700 !important;
color: #1e3a8a !important;
margin-top: 2rem !important;
margin-bottom: 1rem !important;
}
h3 {
font-size: 1.25rem !important;
font-weight: 600 !important;
color: #1e40af !important;
margin-bottom: 0.75rem !important;
}
/* Card Styling */
.blue-card {
background: linear-gradient(135deg, #ffffff 0%, #f0f9ff 100%);
border-radius: 16px;
padding: 2rem;
box-shadow: 0 10px 30px rgba(30, 58, 138, 0.2);
border: 2px solid rgba(59, 130, 246, 0.3);
margin-bottom: 1.5rem;
backdrop-filter: blur(10px);
transition: all 0.3s ease;
}
.blue-card:hover {
transform: translateY(-5px);
box-shadow: 0 15px 40px rgba(30, 58, 138, 0.3);
}
/* Selectbox Styling */
.stSelectbox > div > div {
background: linear-gradient(135deg, #ffffff 0%, #f8fafc 100%) !important;
border: 2px solid #3b82f6 !important;
border-radius: 12px !important;
color: #1e3a8a !important;
}
.stSelectbox label {
font-weight: 600 !important;
color: #1e40af !important;
font-size: 0.9375rem !important;
margin-bottom: 0.75rem !important;
}
.stSelectbox [data-baseweb="select"] {
color: #1e3a8a !important;
font-weight: 500 !important;
}
/* File Uploader Styling */
.stFileUploader {
border: 3px dashed #3b82f6 !important;
border-radius: 16px !important;
padding: 2rem !important;
background: linear-gradient(135deg, rgba(59, 130, 246, 0.1) 0%, rgba(147, 197, 253, 0.1) 100%) !important;
transition: all 0.3s ease !important;
}
.stFileUploader:hover {
border-color: #2563eb !important;
background: linear-gradient(135deg, rgba(59, 130, 246, 0.2) 0%, rgba(147, 197, 253, 0.2) 100%) !important;
transform: scale(1.02);
}
.stFileUploader label {
font-weight: 600 !important;
color: #1e40af !important;
font-size: 0.9375rem !important;
margin-bottom: 0.75rem !important;
}
/* Button Styling */
.stButton > button {
background: linear-gradient(135deg, #3b82f6 0%, #2563eb 50%, #1d4ed8 100%) !important;
color: white !important;
font-weight: 700 !important;
font-size: 1rem !important;
padding: 1rem 2rem !important;
border-radius: 12px !important;
border: none !important;
width: 100% !important;
transition: all 0.3s ease !important;
box-shadow: 0 8px 20px rgba(59, 130, 246, 0.4) !important;
text-transform: uppercase;
letter-spacing: 0.5px;
}
.stButton > button:hover {
background: linear-gradient(135deg, #2563eb 0%, #1d4ed8 50%, #1e40af 100%) !important;
box-shadow: 0 12px 30px rgba(59, 130, 246, 0.6) !important;
transform: translateY(-3px) scale(1.02) !important;
}
/* Text Area Styling */
.stTextArea textarea {
font-family: 'Monaco', 'Menlo', monospace !important;
font-size: 0.875rem !important;
background: #ffffff !important;
border: 2px solid #3b82f6 !important;
border-radius: 12px !important;
color: #1e3a8a !important;
padding: 1rem !important;
}
.stTextArea textarea:disabled {
background: #f8fafc !important;
color: #1e3a8a !important;
-webkit-text-fill-color: #1e3a8a !important;
opacity: 1 !important;
}
.stTextArea textarea::placeholder {
color: #94a3b8 !important;
}
.stTextArea label {
font-weight: 600 !important;
color: #1e40af !important;
font-size: 0.9375rem !important;
}
/* Keep textarea text visible everywhere, including disabled textareas */
textarea {
color: #1e3a8a !important;
}
textarea[disabled],
.stTextArea textarea,
[data-testid="stTextArea"] textarea {
color: #1e3a8a !important;
-webkit-text-fill-color: #1e3a8a !important;
}
textarea[disabled] {
opacity: 1 !important;
}
/* Metric Styling */
.stMetric {
background: linear-gradient(135deg, #ffffff 0%, #eff6ff 100%) !important;
padding: 1.5rem !important;
border-radius: 16px !important;
border: 2px solid rgba(59, 130, 246, 0.3) !important;
box-shadow: 0 8px 20px rgba(30, 58, 138, 0.15) !important;
}
.stMetric label {
font-weight: 600 !important;
color: #3b82f6 !important;
font-size: 0.8125rem !important;
text-transform: uppercase;
letter-spacing: 0.1em;
}
.stMetric [data-testid="stMetricValue"] {
font-weight: 800 !important;
color: #1e3a8a !important;
font-size: 2rem !important;
margin-top: 0.5rem !important;
}
/* Success/Error Messages */
.stSuccess {
background: linear-gradient(135deg, rgba(34, 197, 94, 0.15) 0%, rgba(16, 185, 129, 0.15) 100%) !important;
border-left: 5px solid #22c55e !important;
border-radius: 12px !important;
padding: 1.25rem !important;
color: #065f46 !important;
font-weight: 500 !important;
}
.stError {
background: linear-gradient(135deg, rgba(239, 68, 68, 0.15) 0%, rgba(220, 38, 38, 0.15) 100%) !important;
border-left: 5px solid #ef4444 !important;
border-radius: 12px !important;
padding: 1.25rem !important;
color: #991b1b !important;
font-weight: 500 !important;
}
/* Result Cards */
.result-card {
background: linear-gradient(135deg, #3b82f6 0%, #2563eb 50%, #1d4ed8 100%);
color: white;
padding: 2rem;
border-radius: 16px;
text-align: center;
box-shadow: 0 10px 30px rgba(30, 58, 138, 0.4);
transition: all 0.3s ease;
border: 2px solid rgba(255, 255, 255, 0.2);
}
.result-card:hover {
transform: translateY(-5px) scale(1.05);
box-shadow: 0 15px 40px rgba(30, 58, 138, 0.5);
}
.result-card h4 {
font-size: 0.75rem;
font-weight: 700;
text-transform: uppercase;
letter-spacing: 0.15em;
opacity: 0.95;
margin-bottom: 0.75rem;
}
.result-card p {
font-size: 1.75rem;
font-weight: 800;
margin: 0;
text-shadow: 1px 1px 3px rgba(0,0,0,0.2);
}
/* Progress Bars */
.progress-wrapper {
margin-bottom: 1.5rem;
background: linear-gradient(135deg, #ffffff 0%, #f0f9ff 100%);
padding: 1.25rem;
border-radius: 12px;
border: 2px solid rgba(59, 130, 246, 0.2);
box-shadow: 0 4px 10px rgba(30, 58, 138, 0.1);
}
.progress-label {
display: flex;
justify-content: space-between;
margin-bottom: 0.75rem;
font-size: 0.9375rem;
font-weight: 600;
color: #1e40af;
}
.progress-bar {
height: 12px;
background: #e0e7ff;
border-radius: 8px;
overflow: hidden;
box-shadow: inset 0 2px 4px rgba(0,0,0,0.1);
}
.progress-fill {
height: 100%;
border-radius: 8px;
transition: width 0.5s ease;
box-shadow: 0 2px 8px rgba(59, 130, 246, 0.4);
}
/* Divider */
hr {
border: none;
border-top: 3px solid rgba(255, 255, 255, 0.3);
margin: 2.5rem 0;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
/* Spinner */
.stSpinner > div {
border-color: #3b82f6;
}
/* Text Colors - Ensure text in white containers is dark, not white */
.blue-card, .section-header, .progress-wrapper, .stMetric,
.blue-card p, .section-header p, .progress-wrapper p, .stMetric p,
.blue-card span, .section-header span, .progress-wrapper span, .stMetric span,
.blue-card div, .section-header div, .progress-wrapper div, .stMetric div,
.blue-card *, .section-header *, .progress-wrapper *, .stMetric * {
color: #1e3a8a !important;
}
/* Text on gradient background (not in white containers) should be white */
.main p:not(.blue-card p):not(.section-header p):not(.progress-wrapper p):not(.stMetric p),
.main span:not(.blue-card span):not(.section-header span):not(.progress-wrapper span):not(.stMetric span) {
color: #ffffff !important;
}
/* Caption text should be visible */
.main [data-testid="stCaption"] {
color: rgba(255, 255, 255, 0.8) !important;
}
/* Strong text on gradient background */
.main strong:not(.blue-card strong):not(.section-header strong):not(.progress-wrapper strong):not(.stMetric strong),
.main b:not(.blue-card b):not(.section-header b):not(.progress-wrapper b):not(.stMetric b) {
color: #ffffff !important;
font-weight: 700 !important;
text-shadow: 1px 1px 3px rgba(0,0,0,0.2);
}
/* Override for any nested elements in white containers */
.blue-card strong, .section-header strong, .progress-wrapper strong, .stMetric strong,
.blue-card b, .section-header b, .progress-wrapper b, .stMetric b {
color: #1e3a8a !important;
}
/* Section Headers */
.section-header {
background: linear-gradient(135deg, rgba(255, 255, 255, 0.95) 0%, rgba(240, 249, 255, 0.95) 100%);
padding: 1.5rem;
border-radius: 12px;
border: 2px solid rgba(59, 130, 246, 0.3);
margin-bottom: 1rem;
box-shadow: 0 4px 15px rgba(30, 58, 138, 0.15);
}
</style>
""", unsafe_allow_html=True)
# -------------------------------
# Header
# -------------------------------
# Centered title plus a one-line subtitle, passed to st.markdown as raw HTML.
_HEADER_HTML = """
<div style="text-align: center; margin-bottom: 3rem;">
<h1>📄 PDF Document Classification</h1>
<p style="color: rgba(255,255,255,0.95) !important; font-size: 1.125rem !important; font-weight: 400 !important; margin-top: 0.5rem; text-shadow: 1px 1px 4px rgba(0,0,0,0.2);">
<strong style="font-weight: 600;">AI-Powered Document Type Detection</strong> • Upload a text-based PDF to classify it as Invoice, Contract, or Other
</p>
</div>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
# -------------------------------
# Model Configuration
# -------------------------------
# UI label -> model identifier. Transformer entries are Hugging Face model ids;
# the FastText entry (added below when available) uses the sentinel "fasttext".
MODEL_OPTIONS = {
    "DistilBERT (distilbert-base-uncased)": "distilbert-base-uncased",
    "TinyBERT (huawei-noah/TinyBERT_General_6L_768D)": "huawei-noah/TinyBERT_General_6L_768D"
}
# FastText Configuration
# Look for the model next to this script first. When __file__ is undefined
# (script executed without a backing file) fall back to the working directory.
try:
    _SCRIPT_DIR = os.path.dirname(__file__)
except NameError:
    _SCRIPT_DIR = ""
FASTTEXT_MODEL_PATH = os.path.join(_SCRIPT_DIR, "doc_classifier.bin")
# Fallback to a path relative to the working directory if the file is not
# next to the script.
if not os.path.exists(FASTTEXT_MODEL_PATH):
    FASTTEXT_MODEL_PATH = "doc_classifier.bin"
# Minimum FastText probability for its Invoice/Contract label to be accepted;
# below this the keyword rules decide.
FASTTEXT_THRESHOLD = 0.45
# Add FastText to model options if available
if FASTTEXT_AVAILABLE:
    MODEL_OPTIONS["FastText"] = "fasttext"
# Class index -> display name; used to turn predicted class indices into
# labels, and NUM_LABELS sizes the transformer classification head.
LABELS = dict(enumerate(("Invoice", "Contract", "Other")))
NUM_LABELS = len(LABELS)
# -------------------------------
# Keywords for Rule-Based Classification
# -------------------------------
# Lower-case cue phrases searched for in lower-cased document text by
# rule_based_classify. Invoice cues are checked before contract cues.
INVOICE_KW = [
    "invoice",
    "remit",
    "bill to",
    "net",
    "tax",
    "gst",
    "amount due",
    "purchase order",
    "po number",
]
CONTRACT_KW = [
    "agreement",
    "contract",
    "terms and conditions",
    "party",
    "clause",
    "hereby",
    "effective date",
]
# -------------------------------
# Load Model & Tokenizer
# -------------------------------
@st.cache_resource
def _load_fasttext_model(path):
    """Load a FastText model from ``path``; cached across Streamlit reruns."""
    return fasttext.load_model(path)


@st.cache_resource
def _load_transformer(model_name):
    """Load ``(tokenizer, model)`` for a Hugging Face id in eval mode; cached."""
    # NOTE(review): MODEL_OPTIONS lists what look like base (not fine-tuned)
    # checkpoints. Loading them with num_labels attaches a freshly initialised
    # classification head, so predictions are presumably arbitrary until a
    # fine-tuned checkpoint is used — confirm.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=NUM_LABELS
    )
    model.eval()
    return tokenizer, model


def load_model(model_name):
    """Return ``(tokenizer, model)`` for the model chosen in the UI.

    Streamlit re-runs the whole script on every interaction, so the actual
    loading is delegated to ``st.cache_resource``-decorated helpers. The model
    is therefore loaded once rather than on every click. Failed loads raise
    out of the helpers and are reported here, so they are not cached.

    Args:
        model_name: a value from MODEL_OPTIONS, either ``"fasttext"`` or a
            Hugging Face model id.

    Returns:
        ``(None, fasttext_model)`` for FastText (no tokenizer needed),
        ``(tokenizer, model)`` for a transformer, or ``(None, None)`` after
        an ``st.error`` message when the model cannot be loaded.
    """
    if model_name == "fasttext":
        if not FASTTEXT_AVAILABLE:
            st.error("FastText is not installed. Please install it using: pip install fasttext")
            return None, None
        # Check if model file exists
        if not os.path.exists(FASTTEXT_MODEL_PATH):
            st.error(f"❌ FastText model file not found: {FASTTEXT_MODEL_PATH}\n\nPlease ensure 'doc_classifier.bin' is in the same directory as streamlit_app.py")
            return None, None
        try:
            return None, _load_fasttext_model(FASTTEXT_MODEL_PATH)  # FastText doesn't need a tokenizer
        except FileNotFoundError:
            st.error(f"❌ FastText model file not found: {FASTTEXT_MODEL_PATH}\n\nPlease ensure 'doc_classifier.bin' is in the repository root.")
            return None, None
        except Exception as e:
            st.error(f"❌ Error loading FastText model: {str(e)}\n\nPlease check that the model file is valid.")
            return None, None
    try:
        return _load_transformer(model_name)
    except Exception as e:
        # UI boundary: report download/load failures instead of crashing.
        st.error(f"❌ Error loading model '{model_name}': {e}")
        return None, None
# -------------------------------
# Text Cleaning (for FastText)
# -------------------------------
def clean_text(text: str) -> str:
    """Normalise text for FastText input.

    Lower-cases the text and replaces every character other than ``a-z``,
    ``0-9`` and whitespace with a space. Runs of whitespace are then collapsed
    to single spaces and the ends are trimmed.
    """
    alnum_only = re.sub(r"[^a-z0-9\s]", " ", text.lower())
    return " ".join(alnum_only.split())
# -------------------------------
# Rule-Based Classification
# -------------------------------
def rule_based_classify(text: str, invoice_keywords=None, contract_keywords=None) -> str:
    """Classify ``text`` as Invoice, Contract or Other from keyword cues.

    Keywords are matched case-insensitively as whole words or phrases, with an
    optional plural ``s``/``es`` suffix. Short cues such as "net" or "tax"
    therefore no longer fire inside unrelated words like "internet",
    "syntax" or "Kenneth". Invoice cues take precedence over contract cues.

    Args:
        text: raw document text.
        invoice_keywords: invoice cue phrases; defaults to INVOICE_KW.
        contract_keywords: contract cue phrases; defaults to CONTRACT_KW.

    Returns:
        ``"Invoice"``, ``"Contract"`` or ``"Other"``.
    """
    if invoice_keywords is None:
        invoice_keywords = INVOICE_KW
    if contract_keywords is None:
        contract_keywords = CONTRACT_KW
    text_lower = text.lower()

    def _mentions_any(keywords):
        # Lookarounds instead of \b so keywords that begin or end with a
        # non-word character still match correctly.
        return any(
            re.search(rf"(?<!\w){re.escape(k.lower())}(?:e?s)?(?!\w)", text_lower)
            for k in keywords
        )

    if _mentions_any(invoice_keywords):
        return "Invoice"
    if _mentions_any(contract_keywords):
        return "Contract"
    return "Other"
# -------------------------------
# FastText Classification (NumPy 2.x Compatible)
# -------------------------------
def classify_with_fasttext(text: str, model):
    """Classify ``text`` with FastText, falling back to keyword rules.

    FastText's top label is used only when its probability is at least
    FASTTEXT_THRESHOLD and the label is Invoice or Contract. Otherwise the
    label comes from rule_based_classify.

    Args:
        text: raw document text (normalised internally with clean_text).
        model: a loaded FastText model exposing ``predict(text, k=...)``.

    Returns:
        ``(label, confidence, method)`` with ``method`` set to ``"fasttext"`` or
        ``"rule-based"``. When the rules override a low-confidence prediction,
        ``confidence`` is FastText's probability for its own top label. When
        FastText cannot be used at all, ``confidence`` is a placeholder 0.5.
    """
    cleaned_text = clean_text(text)
    if not cleaned_text:
        # Nothing left to score (e.g. only punctuation or non-Latin script);
        # FastText may return no labels for empty input.
        return rule_based_classify(text), 0.5, "rule-based"
    try:
        # NumPy 2.x copy=False handling is provided by the module-level shim.
        labels, probs = model.predict(cleaned_text, k=1)
        if len(labels) == 0:
            raise ValueError("FastText returned no prediction")
        label = labels[0].replace("__label__", "").capitalize()
        confidence = float(probs[0])
    except (ValueError, TypeError, AttributeError, IndexError) as e:
        # Fallback to rule-based if FastText prediction fails
        error_msg = str(e)
        lowered = error_msg.lower()
        if "copy" in lowered or "numpy" in lowered:
            st.warning("⚠️ FastText encountered a NumPy 2.x compatibility issue. Using rule-based classification as fallback.")
        else:
            st.warning(f"⚠️ FastText prediction encountered an issue, using rule-based classification: {error_msg}")
        return rule_based_classify(text), 0.5, "rule-based"
    if confidence >= FASTTEXT_THRESHOLD and label in ("Invoice", "Contract"):
        return label, confidence, "fasttext"
    return rule_based_classify(text), confidence, "rule-based"
# -------------------------------
# PDF Text Extraction
# -------------------------------
def extract_text_from_pdf(pdf_file):
    """Return the concatenated text of every page of a PDF, stripped.

    Args:
        pdf_file: binary file-like object containing the PDF. Objects with
            ``getvalue()`` (e.g. ``io.BytesIO`` / Streamlit uploads) are read in
            full regardless of their current position; others via ``read()``.

    Returns:
        The extracted text; an empty string when the PDF has no text layer.
    """
    getvalue = getattr(pdf_file, "getvalue", None)
    data = getvalue() if callable(getvalue) else pdf_file.read()
    # The context manager closes the document (the original leaked it).
    with fitz.open(stream=data, filetype="pdf") as doc:
        text = "".join(page.get_text("text") for page in doc)
    return text.strip()
# -------------------------------
# Main UI
# -------------------------------
# Two side-by-side panels: model picker on the left, PDF upload on the right.
model_col, upload_col = st.columns(2, gap="large")

with model_col:
    st.markdown('<div class="section-header"><h3 style="margin: 0; color: #1e40af !important;">🤖 Model Selection</h3></div>', unsafe_allow_html=True)
    selected_model_label = st.selectbox(
        "Choose your AI model",
        list(MODEL_OPTIONS),
        label_visibility="visible"
    )
    selected_model_name = MODEL_OPTIONS[selected_model_label]
    tokenizer, model = load_model(selected_model_name)
    # load_model has already shown an error when it returns no model; end this
    # script run here, before the upload panel is drawn.
    if model is None:
        st.stop()

with upload_col:
    st.markdown('<div class="section-header"><h3 style="margin: 0; color: #1e40af !important;">📎 Document Upload</h3></div>', unsafe_allow_html=True)
    uploaded_file = st.file_uploader(
        "Upload your PDF file",
        type=["pdf"],
        label_visibility="visible",
        help="Select a text-based PDF file to classify"
    )
# -------------------------------
# Processing
# -------------------------------
import html  # escapes untrusted PDF text before embedding it in raw HTML

# Number of leading characters of the extracted text shown in the preview.
PREVIEW_CHARS = 2000

if uploaded_file:
    st.markdown("---")
    with st.spinner("🔍 Extracting text from PDF..."):
        pdf_text = extract_text_from_pdf(uploaded_file)
    if not pdf_text:
        st.error("**❌ No Text Found**\n\nThis PDF appears to be image-based. Please upload a text-based PDF file.")
    else:
        st.success("**✅ Text Extracted Successfully!**\n\nReady for classification")
        st.markdown('<div class="section-header"><h3 style="margin: 0; color: #1e40af !important;">📝 Text Preview</h3></div>', unsafe_allow_html=True)
        # The document text is untrusted and rendered with unsafe_allow_html.
        # Escape it before converting newlines to <br>, so markup inside the
        # PDF cannot be injected into the page.
        preview_text = pdf_text[:PREVIEW_CHARS]
        preview_html = html.escape(preview_text).replace("\n", "<br>").replace("\r", "")
        st.markdown(f"""
<div style="background: #ffffff; border: 2px solid #3b82f6; border-radius: 12px; padding: 1.5rem; max-height: 300px; overflow-y: auto; font-family: 'Monaco', 'Menlo', monospace; font-size: 0.875rem; color: #1e3a8a; line-height: 1.6;">
{preview_html}
</div>
""", unsafe_allow_html=True)
        total_chars = len(pdf_text)
        if total_chars > PREVIEW_CHARS:
            st.caption(f"Showing first {PREVIEW_CHARS} characters of {total_chars} total characters")
        else:
            st.caption(f"Showing all {total_chars} characters")
        st.markdown("---")
        col_btn1, col_btn2, col_btn3 = st.columns([1, 2, 1])
        with col_btn2:
            classify_btn = st.button("🚀 Classify Document", use_container_width=True)
        if classify_btn:
            with st.spinner("🤖 Running AI inference..."):
                start_time = time.perf_counter()
                if selected_model_name == "fasttext":
                    doc_type, confidence, method = classify_with_fasttext(pdf_text, model)
                    conf_percent = confidence * 100
                    inference_time = time.perf_counter() - start_time
                    # FastText is queried with k=1, so only the chosen label's
                    # score is known. The remainder is split across the other
                    # labels with fixed heuristic weights, for display only.
                    prob_dict = {}
                    if doc_type == "Invoice":
                        prob_dict["Invoice"] = conf_percent
                        prob_dict["Contract"] = max(0, (100 - conf_percent) * 0.3)
                        prob_dict["Other"] = max(0, (100 - conf_percent) * 0.7)
                    elif doc_type == "Contract":
                        prob_dict["Contract"] = conf_percent
                        prob_dict["Invoice"] = max(0, (100 - conf_percent) * 0.3)
                        prob_dict["Other"] = max(0, (100 - conf_percent) * 0.7)
                    else:
                        prob_dict["Other"] = conf_percent
                        prob_dict["Invoice"] = max(0, (100 - conf_percent) * 0.4)
                        prob_dict["Contract"] = max(0, (100 - conf_percent) * 0.6)
                    # Normalize probabilities to sum to 100.
                    total = sum(prob_dict.values())
                    prob_dict = {k: (v / total * 100) if total > 0 else 33.33 for k, v in prob_dict.items()}
                else:
                    # Transformer model classification. The first 512
                    # characters are sent; truncation=True lets the tokenizer
                    # cut the token sequence to the model's limit if needed.
                    inputs = tokenizer(
                        pdf_text[:512],
                        return_tensors="pt",
                        truncation=True,
                        padding=True
                    )
                    with torch.no_grad():
                        outputs = model(**inputs)
                    inference_time = time.perf_counter() - start_time
                    probs = F.softmax(outputs.logits, dim=1)
                    confidence, predicted_class = torch.max(probs, dim=1)
                    doc_type = LABELS[predicted_class.item()]
                    conf_percent = confidence.item() * 100
                    method = "transformer"
                    prob_dict = {LABELS[i]: probs[0][i].item() * 100 for i in range(len(LABELS))}

            # Results
            st.markdown("---")
            st.markdown('<div class="section-header"><h2 style="margin: 0; text-align: center;">🎯 Classification Results</h2></div>', unsafe_allow_html=True)
            col1, col2, col3, col4 = st.columns(4, gap="medium")
            with col1:
                st.markdown(f"""
<div class="result-card">
<h4>Model</h4>
<p style="font-size: 1.25rem;">{selected_model_label.split('(')[0].strip()}</p>
</div>
""", unsafe_allow_html=True)
            with col2:
                st.markdown(f"""
<div class="result-card">
<h4>Document Type</h4>
<p>{doc_type}</p>
</div>
""", unsafe_allow_html=True)
            with col3:
                # Green >= 80 %, amber >= 60 %, red otherwise.
                if conf_percent >= 80:
                    conf_gradient = "linear-gradient(135deg, #22c55e 0%, #16a34a 100%)"
                elif conf_percent >= 60:
                    conf_gradient = "linear-gradient(135deg, #f59e0b 0%, #d97706 100%)"
                else:
                    conf_gradient = "linear-gradient(135deg, #ef4444 0%, #dc2626 100%)"
                st.markdown(f"""
<div class="result-card" style="background: {conf_gradient};">
<h4>Confidence</h4>
<p>{conf_percent:.1f}%</p>
</div>
""", unsafe_allow_html=True)
            with col4:
                method_gradient = {
                    "transformer": "linear-gradient(135deg, #8b5cf6 0%, #7c3aed 100%)",
                    "fasttext": "linear-gradient(135deg, #06b6d4 0%, #0891b2 100%)",
                }.get(method, "linear-gradient(135deg, #f59e0b 0%, #d97706 100%)")
                st.markdown(f"""
<div class="result-card" style="background: {method_gradient};">
<h4>Method</h4>
<p style="font-size: 1rem; text-transform: capitalize;">{method}</p>
</div>
""", unsafe_allow_html=True)
            st.markdown("<br>", unsafe_allow_html=True)
            # Inference time
            col_time1, col_time2, col_time3 = st.columns([1, 2, 1])
            with col_time2:
                st.metric("⚡ Inference Time", f"{inference_time:.3f} seconds")
            # Confidence breakdown, highest probability first; the predicted
            # label gets the darker bar.
            st.markdown('<div class="section-header"><h2 style="margin: 0; text-align: center;">📈 Confidence Breakdown</h2></div>', unsafe_allow_html=True)
            sorted_probs = sorted(prob_dict.items(), key=lambda x: x[1], reverse=True)
            for label, prob in sorted_probs:
                if label == doc_type:
                    bar_gradient = "linear-gradient(90deg, #3b82f6 0%, #2563eb 100%)"
                else:
                    bar_gradient = "linear-gradient(90deg, #93c5fd 0%, #60a5fa 100%)"
                st.markdown(f"""
<div class="progress-wrapper">
<div class="progress-label">
<span>{label}</span>
<span style="font-weight: 700; color: #1e3a8a;">{prob:.1f}%</span>
</div>
<div class="progress-bar">
<div class="progress-fill" style="background: {bar_gradient}; width: {prob}%;"></div>
</div>
</div>
""", unsafe_allow_html=True)