# Spaces: Sleeping (Hugging Face Spaces page-header residue captured with this file)
import html
import time

import fitz  # PyMuPDF
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
# -------------------------------
# Page Configuration
# -------------------------------
# Configure the browser tab title/icon and the page layout before anything is
# rendered. The icon was previously mis-decoded text ("π") rather than an emoji.
st.set_page_config(
    page_title="PDF Document Classification",
    page_icon="📄",
    layout="wide",
    initial_sidebar_state="collapsed",
)
# -------------------------------
# Beautiful Blue Theme CSS
# -------------------------------
# Global theme injected on every run. The selectors target Streamlit's rendered
# class names and data-testid attributes, which are not a stable public API —
# re-check them after upgrading Streamlit.
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;500;600;700;800&display=swap');
* {
    font-family: 'Poppins', sans-serif;
}
.stApp {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 25%, #667eea 50%, #4facfe 75%, #00f2fe 100%);
    background-size: 400% 400%;
    animation: gradientShift 15s ease infinite;
}
@keyframes gradientShift {
    0% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
    100% { background-position: 0% 50%; }
}
.main .block-container {
    max-width: 1100px;
    padding-top: 2rem;
    padding-bottom: 3rem;
}
/* Header Styling */
h1 {
    font-size: 3rem !important;
    font-weight: 800 !important;
    color: #ffffff !important;
    text-align: center;
    margin-bottom: 0.5rem !important;
    text-shadow: 2px 2px 8px rgba(0,0,0,0.2);
    letter-spacing: -0.02em;
}
h2 {
    font-size: 1.75rem !important;
    font-weight: 700 !important;
    color: #1e3a8a !important;
    margin-top: 2rem !important;
    margin-bottom: 1rem !important;
}
h3 {
    font-size: 1.25rem !important;
    font-weight: 600 !important;
    color: #1e40af !important;
    margin-bottom: 0.75rem !important;
}
/* Card Styling */
.blue-card {
    background: linear-gradient(135deg, #ffffff 0%, #f0f9ff 100%);
    border-radius: 16px;
    padding: 2rem;
    box-shadow: 0 10px 30px rgba(30, 58, 138, 0.2);
    border: 2px solid rgba(59, 130, 246, 0.3);
    margin-bottom: 1.5rem;
    backdrop-filter: blur(10px);
    transition: all 0.3s ease;
}
.blue-card:hover {
    transform: translateY(-5px);
    box-shadow: 0 15px 40px rgba(30, 58, 138, 0.3);
}
/* Selectbox Styling */
.stSelectbox > div > div {
    background: linear-gradient(135deg, #ffffff 0%, #f8fafc 100%) !important;
    border: 2px solid #3b82f6 !important;
    border-radius: 12px !important;
    color: #1e3a8a !important;
}
.stSelectbox label {
    font-weight: 600 !important;
    color: #1e40af !important;
    font-size: 0.9375rem !important;
    margin-bottom: 0.75rem !important;
}
.stSelectbox [data-baseweb="select"] {
    color: #1e3a8a !important;
    font-weight: 500 !important;
}
/* File Uploader Styling */
.stFileUploader {
    border: 3px dashed #3b82f6 !important;
    border-radius: 16px !important;
    padding: 2rem !important;
    background: linear-gradient(135deg, rgba(59, 130, 246, 0.1) 0%, rgba(147, 197, 253, 0.1) 100%) !important;
    transition: all 0.3s ease !important;
}
.stFileUploader:hover {
    border-color: #2563eb !important;
    background: linear-gradient(135deg, rgba(59, 130, 246, 0.2) 0%, rgba(147, 197, 253, 0.2) 100%) !important;
    transform: scale(1.02);
}
.stFileUploader label {
    font-weight: 600 !important;
    color: #1e40af !important;
    font-size: 0.9375rem !important;
    margin-bottom: 0.75rem !important;
}
/* Button Styling */
.stButton > button {
    background: linear-gradient(135deg, #3b82f6 0%, #2563eb 50%, #1d4ed8 100%) !important;
    color: white !important;
    font-weight: 700 !important;
    font-size: 1rem !important;
    padding: 1rem 2rem !important;
    border-radius: 12px !important;
    border: none !important;
    width: 100% !important;
    transition: all 0.3s ease !important;
    box-shadow: 0 8px 20px rgba(59, 130, 246, 0.4) !important;
    text-transform: uppercase;
    letter-spacing: 0.5px;
}
.stButton > button:hover {
    background: linear-gradient(135deg, #2563eb 0%, #1d4ed8 50%, #1e40af 100%) !important;
    box-shadow: 0 12px 30px rgba(59, 130, 246, 0.6) !important;
    transform: translateY(-3px) scale(1.02) !important;
}
/* Text Area Styling */
.stTextArea textarea {
    font-family: 'Monaco', 'Menlo', monospace !important;
    font-size: 0.875rem !important;
    background: #ffffff !important;
    border: 2px solid #3b82f6 !important;
    border-radius: 12px !important;
    color: #1e3a8a !important;
    padding: 1rem !important;
}
.stTextArea textarea:disabled {
    background: #f8fafc !important;
    color: #1e3a8a !important;
    -webkit-text-fill-color: #1e3a8a !important;
    opacity: 1 !important;
}
.stTextArea textarea::placeholder {
    color: #94a3b8 !important;
}
.stTextArea label {
    font-weight: 600 !important;
    color: #1e40af !important;
    font-size: 0.9375rem !important;
}
/* Keep textarea text readable: WebKit paints disabled text via
   -webkit-text-fill-color and fades it with opacity, so both are pinned. */
textarea {
    color: #1e3a8a !important;
}
textarea[disabled] {
    color: #1e3a8a !important;
    -webkit-text-fill-color: #1e3a8a !important;
    opacity: 1 !important;
}
.stTextArea textarea,
[data-testid="stTextArea"] textarea {
    color: #1e3a8a !important;
    -webkit-text-fill-color: #1e3a8a !important;
}
/* Metric Styling */
.stMetric {
    background: linear-gradient(135deg, #ffffff 0%, #eff6ff 100%) !important;
    padding: 1.5rem !important;
    border-radius: 16px !important;
    border: 2px solid rgba(59, 130, 246, 0.3) !important;
    box-shadow: 0 8px 20px rgba(30, 58, 138, 0.15) !important;
}
.stMetric label {
    font-weight: 600 !important;
    color: #3b82f6 !important;
    font-size: 0.8125rem !important;
    text-transform: uppercase;
    letter-spacing: 0.1em;
}
.stMetric [data-testid="stMetricValue"] {
    font-weight: 800 !important;
    color: #1e3a8a !important;
    font-size: 2rem !important;
    margin-top: 0.5rem !important;
}
/* Success/Error Messages */
.stSuccess {
    background: linear-gradient(135deg, rgba(34, 197, 94, 0.15) 0%, rgba(16, 185, 129, 0.15) 100%) !important;
    border-left: 5px solid #22c55e !important;
    border-radius: 12px !important;
    padding: 1.25rem !important;
    color: #065f46 !important;
    font-weight: 500 !important;
}
.stError {
    background: linear-gradient(135deg, rgba(239, 68, 68, 0.15) 0%, rgba(220, 38, 38, 0.15) 100%) !important;
    border-left: 5px solid #ef4444 !important;
    border-radius: 12px !important;
    padding: 1.25rem !important;
    color: #991b1b !important;
    font-weight: 500 !important;
}
/* Result Cards */
.result-card {
    background: linear-gradient(135deg, #3b82f6 0%, #2563eb 50%, #1d4ed8 100%);
    color: white;
    padding: 2rem;
    border-radius: 16px;
    text-align: center;
    box-shadow: 0 10px 30px rgba(30, 58, 138, 0.4);
    transition: all 0.3s ease;
    border: 2px solid rgba(255, 255, 255, 0.2);
}
.result-card:hover {
    transform: translateY(-5px) scale(1.05);
    box-shadow: 0 15px 40px rgba(30, 58, 138, 0.5);
}
.result-card h4 {
    font-size: 0.75rem;
    font-weight: 700;
    text-transform: uppercase;
    letter-spacing: 0.15em;
    opacity: 0.95;
    margin-bottom: 0.75rem;
}
.result-card p {
    font-size: 1.75rem;
    font-weight: 800;
    margin: 0;
    text-shadow: 1px 1px 3px rgba(0,0,0,0.2);
}
/* Progress Bars */
.progress-wrapper {
    margin-bottom: 1.5rem;
    background: linear-gradient(135deg, #ffffff 0%, #f0f9ff 100%);
    padding: 1.25rem;
    border-radius: 12px;
    border: 2px solid rgba(59, 130, 246, 0.2);
    box-shadow: 0 4px 10px rgba(30, 58, 138, 0.1);
}
.progress-label {
    display: flex;
    justify-content: space-between;
    margin-bottom: 0.75rem;
    font-size: 0.9375rem;
    font-weight: 600;
    color: #1e40af;
}
.progress-bar {
    height: 12px;
    background: #e0e7ff;
    border-radius: 8px;
    overflow: hidden;
    box-shadow: inset 0 2px 4px rgba(0,0,0,0.1);
}
.progress-fill {
    height: 100%;
    border-radius: 8px;
    transition: width 0.5s ease;
    box-shadow: 0 2px 8px rgba(59, 130, 246, 0.4);
}
/* Divider */
hr {
    border: none;
    border-top: 3px solid rgba(255, 255, 255, 0.3);
    margin: 2.5rem 0;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
/* Spinner */
.stSpinner > div {
    border-color: #3b82f6;
}
/* Text Colors - Ensure text in white containers is dark, not white */
.blue-card, .section-header, .progress-wrapper, .stMetric,
.blue-card p, .section-header p, .progress-wrapper p, .stMetric p,
.blue-card span, .section-header span, .progress-wrapper span, .stMetric span,
.blue-card div, .section-header div, .progress-wrapper div, .stMetric div,
.blue-card *, .section-header *, .progress-wrapper *, .stMetric * {
    color: #1e3a8a !important;
}
/* Text on gradient background (not in white containers) should be white */
.main p:not(.blue-card p):not(.section-header p):not(.progress-wrapper p):not(.stMetric p),
.main span:not(.blue-card span):not(.section-header span):not(.progress-wrapper span):not(.stMetric span) {
    color: #ffffff !important;
}
/* Caption text should be visible */
.main [data-testid="stCaption"] {
    color: rgba(255, 255, 255, 0.8) !important;
}
/* Strong text on gradient background */
.main strong:not(.blue-card strong):not(.section-header strong):not(.progress-wrapper strong):not(.stMetric strong),
.main b:not(.blue-card b):not(.section-header b):not(.progress-wrapper b):not(.stMetric b) {
    color: #ffffff !important;
    font-weight: 700 !important;
    text-shadow: 1px 1px 3px rgba(0,0,0,0.2);
}
/* Override for any nested elements in white containers */
.blue-card strong, .section-header strong, .progress-wrapper strong, .stMetric strong,
.blue-card b, .section-header b, .progress-wrapper b, .stMetric b {
    color: #1e3a8a !important;
}
/* Section Headers */
.section-header {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.95) 0%, rgba(240, 249, 255, 0.95) 100%);
    padding: 1.5rem;
    border-radius: 12px;
    border: 2px solid rgba(59, 130, 246, 0.3);
    margin-bottom: 1rem;
    box-shadow: 0 4px 15px rgba(30, 58, 138, 0.15);
}
</style>
""", unsafe_allow_html=True)
# -------------------------------
# Header
# -------------------------------
# Page title and tagline, rendered as raw HTML so the theme's h1 styling applies.
# Emoji/bullet characters restored from mis-decoded text ("π", "β’").
st.markdown("""
<div style="text-align: center; margin-bottom: 3rem;">
<h1>📄 PDF Document Classification</h1>
<p style="color: rgba(255,255,255,0.95) !important; font-size: 1.125rem !important; font-weight: 400 !important; margin-top: 0.5rem; text-shadow: 1px 1px 4px rgba(0,0,0,0.2);">
<strong style="font-weight: 600;">AI-Powered Document Type Detection</strong> • Upload a text-based PDF to classify it as Invoice, Contract, or Other
</p>
</div>
""", unsafe_allow_html=True)
# -------------------------------
# Model Configuration
# -------------------------------
# Selectbox label -> Hugging Face Hub checkpoint id. Each label is the short
# display name followed by the checkpoint id in parentheses.
# NOTE(review): both ids look like general-purpose pretrained encoders rather
# than checkpoints fine-tuned for invoice/contract/other — TODO confirm. If so,
# the sequence-classification head is untrained and predictions are arbitrary.
MODEL_OPTIONS = {
    f"{display_name} ({checkpoint})": checkpoint
    for display_name, checkpoint in (
        ("DistilBERT", "distilbert-base-uncased"),
        ("TinyBERT", "huawei-noah/TinyBERT_General_6L_768D"),
    )
}
# Model output index -> human-readable document type.
LABELS = dict(enumerate(("Invoice", "Contract", "Other")))
NUM_LABELS = len(LABELS)
# -------------------------------
# Load Model & Tokenizer
# -------------------------------
@st.cache_resource
def load_model(model_name):
    """Load the tokenizer and sequence-classification model for a checkpoint.

    Streamlit re-executes this whole script on every widget interaction, so
    the result is cached with ``st.cache_resource`` (one entry per
    ``model_name``). Without the cache, the checkpoint is reloaded on every
    rerun. For a checkpoint that has no trained classification head,
    ``transformers`` also initialises that head randomly on each load, so the
    same PDF could be classified differently from one click to the next.

    Args:
        model_name: Hugging Face Hub id (or local path) of the checkpoint.

    Returns:
        ``(tokenizer, model)``. The model has ``NUM_LABELS`` outputs and is in
        eval mode (dropout disabled).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=NUM_LABELS,
    )
    model.eval()
    return tokenizer, model
# -------------------------------
# PDF Text Extraction
# -------------------------------
def extract_text_from_pdf(pdf_file):
    """Return the plain text of every page of a PDF, concatenated in page order.

    Args:
        pdf_file: Binary file-like object (e.g. a Streamlit ``UploadedFile``)
            whose ``read()`` returns the PDF bytes.

    Returns:
        The text of all pages with leading/trailing whitespace stripped. An
        empty string means PyMuPDF found no text layer (as with scanned,
        image-only PDFs).

    Raises:
        Whatever ``fitz.open`` raises for unreadable input (PyMuPDF uses
        RuntimeError subclasses such as FileDataError).
    """
    data = pdf_file.read()
    # The context manager closes the document and releases MuPDF's resources
    # even if extraction fails. join() avoids quadratic string concatenation.
    with fitz.open(stream=data, filetype="pdf") as doc:
        text = "".join(page.get_text("text") for page in doc)
    return text.strip()
# -------------------------------
# Main UI
# -------------------------------
# Two side-by-side panels: model choice (left) and PDF upload (right).
# Heading emoji restored from mis-decoded text ("π€", "π").
col1, col2 = st.columns(2, gap="large")
with col1:
    st.markdown('<div class="section-header"><h3 style="margin: 0; color: #1e40af !important;">🤖 Model Selection</h3></div>', unsafe_allow_html=True)
    selected_model_label = st.selectbox(
        "Choose your AI model",
        list(MODEL_OPTIONS),
        label_visibility="visible",
    )
    selected_model_name = MODEL_OPTIONS[selected_model_label]
    tokenizer, model = load_model(selected_model_name)
with col2:
    st.markdown('<div class="section-header"><h3 style="margin: 0; color: #1e40af !important;">📁 Document Upload</h3></div>', unsafe_allow_html=True)
    uploaded_file = st.file_uploader(
        "Upload your PDF file",
        type=["pdf"],
        label_visibility="visible",
        help="Select a text-based PDF file to classify",
    )
# -------------------------------
# Processing
# -------------------------------
PREVIEW_CHARS = 2000  # number of extracted characters shown in the preview
# Input window passed to the tokenizer. 512 matches BERT-style encoders such
# as the ones offered above — TODO confirm when adding other models.
MAX_TOKENS = 512


def _result_card_html(title, value, card_style="", value_style=""):
    """Return the HTML for one ``.result-card`` tile: a small title over a large value."""
    card_attr = f' style="{card_style}"' if card_style else ""
    value_attr = f' style="{value_style}"' if value_style else ""
    return (
        f'<div class="result-card"{card_attr}>'
        f"<h4>{html.escape(title)}</h4>"
        f"<p{value_attr}>{html.escape(value)}</p>"
        "</div>"
    )


def _confidence_gradient(conf_percent):
    """Return the confidence-card background: green >= 80%, amber >= 60%, otherwise red."""
    if conf_percent >= 80:
        return "linear-gradient(135deg, #22c55e 0%, #16a34a 100%)"
    if conf_percent >= 60:
        return "linear-gradient(135deg, #f59e0b 0%, #d97706 100%)"
    return "linear-gradient(135deg, #ef4444 0%, #dc2626 100%)"


if uploaded_file:
    st.markdown("---")
    with st.spinner("🔍 Extracting text from PDF..."):
        try:
            pdf_text = extract_text_from_pdf(uploaded_file)
        except RuntimeError:
            # PyMuPDF reports corrupt or non-PDF input via RuntimeError
            # subclasses (e.g. FileDataError). Show a message, not a traceback.
            pdf_text = None

    if pdf_text is None:
        st.error("**❌ Could Not Read PDF**\n\nThe file appears to be corrupt or is not a valid PDF.")
    elif not pdf_text:
        st.error("**❌ No Text Found**\n\nThis PDF appears to be image-based. Please upload a text-based PDF file.")
    else:
        st.success("**✅ Text Extracted Successfully!**\n\nReady for classification")
        st.markdown('<div class="section-header"><h3 style="margin: 0; color: #1e40af !important;">📝 Text Preview</h3></div>', unsafe_allow_html=True)

        # The PDF text is untrusted and is rendered with unsafe_allow_html=True.
        # Escape it so "<" and "&" display literally and the document cannot
        # inject its own markup into the page. Newlines become <br> so the
        # whole preview stays on one line of raw HTML.
        preview_html = (
            html.escape(pdf_text[:PREVIEW_CHARS])
            .replace("\n", "<br>")
            .replace("\r", "")
        )
        st.markdown(
            '<div style="background: #ffffff; border: 2px solid #3b82f6; '
            "border-radius: 12px; padding: 1.5rem; max-height: 300px; "
            "overflow-y: auto; font-family: 'Monaco', 'Menlo', monospace; "
            f'font-size: 0.875rem; color: #1e3a8a; line-height: 1.6;">{preview_html}</div>',
            unsafe_allow_html=True,
        )
        total_chars = len(pdf_text)
        if total_chars > PREVIEW_CHARS:
            st.caption(f"Showing first {PREVIEW_CHARS} characters of {total_chars} total characters")
        else:
            st.caption(f"Showing all {total_chars} characters")

        st.markdown("---")
        _, col_btn, _ = st.columns([1, 2, 1])
        with col_btn:
            classify_btn = st.button("🚀 Classify Document", use_container_width=True)

        if classify_btn:
            with st.spinner("🤖 Running AI inference..."):
                start_time = time.perf_counter()
                # Truncate in tokens, not characters. Slicing the text to 512
                # characters would typically fill only about a quarter of the
                # model's 512-token window.
                inputs = tokenizer(
                    pdf_text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=MAX_TOKENS,
                )
                with torch.no_grad():
                    outputs = model(**inputs)
                inference_time = time.perf_counter() - start_time

            # One input sequence, so take row 0: one probability per label.
            probs = F.softmax(outputs.logits, dim=-1)[0]
            confidence, predicted_class = torch.max(probs, dim=0)
            doc_type = LABELS[predicted_class.item()]
            conf_percent = confidence.item() * 100

            # Results
            st.markdown("---")
            st.markdown('<div class="section-header"><h2 style="margin: 0; text-align: center;">🎯 Classification Results</h2></div>', unsafe_allow_html=True)

            model_col, type_col, conf_col = st.columns(3, gap="medium")
            with model_col:
                st.markdown(
                    _result_card_html(
                        "Model",
                        selected_model_label.split("(")[0].strip(),
                        value_style="font-size: 1.25rem;",
                    ),
                    unsafe_allow_html=True,
                )
            with type_col:
                st.markdown(_result_card_html("Document Type", doc_type), unsafe_allow_html=True)
            with conf_col:
                st.markdown(
                    _result_card_html(
                        "Confidence",
                        f"{conf_percent:.1f}%",
                        card_style=f"background: {_confidence_gradient(conf_percent)};",
                    ),
                    unsafe_allow_html=True,
                )

            st.markdown("<br>", unsafe_allow_html=True)

            # Inference time (tokenization + forward pass)
            _, time_col, _ = st.columns([1, 2, 1])
            with time_col:
                st.metric("⚡ Inference Time", f"{inference_time:.3f} seconds")

            # Confidence breakdown, highest probability first
            st.markdown('<div class="section-header"><h2 style="margin: 0; text-align: center;">📊 Confidence Breakdown</h2></div>', unsafe_allow_html=True)
            prob_by_label = {label: probs[idx].item() * 100 for idx, label in LABELS.items()}
            for label, prob in sorted(prob_by_label.items(), key=lambda item: item[1], reverse=True):
                bar_gradient = (
                    "linear-gradient(90deg, #3b82f6 0%, #2563eb 100%)"
                    if label == doc_type
                    else "linear-gradient(90deg, #93c5fd 0%, #60a5fa 100%)"
                )
                st.markdown(
                    '<div class="progress-wrapper">'
                    '<div class="progress-label">'
                    f"<span>{label}</span>"
                    f'<span style="font-weight: 700; color: #1e3a8a;">{prob:.1f}%</span>'
                    "</div>"
                    '<div class="progress-bar">'
                    f'<div class="progress-fill" style="background: {bar_gradient}; width: {prob:.2f}%;"></div>'
                    "</div>"
                    "</div>",
                    unsafe_allow_html=True,
                )