Spaces:
Sleeping
Sleeping
| import os | |
| import io | |
| import re | |
| import html as _html | |
| import hashlib | |
| import time | |
| import streamlit as st | |
| import streamlit.components.v1 as components | |
| import torch | |
| import pytesseract | |
| import fitz # PyMuPDF | |
| from PIL import Image | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from peft import PeftModel | |
| import torch.nn.functional as F | |
| # ========================= | |
| # STREAMLIT PAGE CONFIG (MUST BE FIRST STREAMLIT COMMAND) | |
| # ========================= | |
def _configure_page():
    """
    Configure Streamlit page settings (title, icon, wide layout).

    This file can be executed directly (streamlit run streamlit_app.py) OR
    imported by `app.py` (Docker entrypoint). Streamlit only allows calling
    `st.set_page_config()` once per page, so this helper is idempotent.
    """
    page_settings = {
        "page_title": "OCR Document Classifier",
        "page_icon": "📄",
        "layout": "wide",
    }
    try:
        st.set_page_config(**page_settings)
    except Exception:
        # If another module (e.g., `app.py`) already called set_page_config,
        # Streamlit raises StreamlitAPIException. Swallow it so a second
        # call never crashes the app.
        pass
# Must run before any other Streamlit command in this module.
_configure_page()
# =========================
# TESSERACT CONFIG (WINDOWS)
# =========================
if os.name == "nt":
    # Allow override via env var; keep a sensible Windows default.
    pytesseract.pytesseract.tesseract_cmd = os.getenv(
        "TESSERACT_CMD", r"C:\Program Files\Tesseract-OCR\tesseract.exe"
    )
    # Point Tesseract at its language data unless the caller already set it.
    os.environ.setdefault("TESSDATA_PREFIX", r"C:\Program Files\Tesseract-OCR\tessdata")
| # ========================= | |
| # MODEL PATHS & CONFIG | |
| # ========================= | |
BASE_MODEL = "prajjwal1/bert-tiny"
ADAPTER_PATH = "./lora_adapter"
MAX_LENGTH = 256

# Class names in training order — list index doubles as the model's class id.
LABELS = [
    "Employment Letter",
    "Lease / Agreement",
    "Bank Statements",
    "Paystub / Payslip",
    "Property Tax",
    "Investment",
    "Tax Documents",
    "Other Documents",
    "ID / License",
]

# Forward / reverse lookup tables between label text and class id.
label2id = {name: idx for idx, name in enumerate(LABELS)}
id2label = dict(enumerate(LABELS))
NUM_LABELS = len(LABELS)
| # ========================= | |
| # LOAD MODEL | |
| # ========================= | |
@st.cache_resource
def load_model():
    """
    Load the tokenizer and the LoRA-adapted classifier once per process.

    Streamlit re-executes this script on every user interaction; without
    caching, the base model and adapter would be re-read from disk on each
    rerun. `st.cache_resource` (available in the pinned Streamlit 1.27.0)
    keeps a single shared instance across reruns and sessions.

    Returns:
        tuple: (tokenizer, model) — the tokenizer saved alongside the
        adapter and the PEFT-wrapped sequence-classification model in
        eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH)
    base_model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL,
        num_labels=NUM_LABELS,
        id2label=id2label,
        label2id=label2id,
    )
    # Attach the LoRA adapter weights on top of the base model.
    model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
    model.eval()  # inference only — disables dropout
    return tokenizer, model


tokenizer, model = load_model()
| # ========================= | |
| # OCR TEXT CLEANING | |
| # ========================= | |
def clean_ocr_text(text):
    """
    Normalize raw OCR output into a single clean line of text.

    Steps: convert newlines to spaces, drop every character outside
    [A-Za-z0-9.,$ ], then collapse runs of whitespace and trim the ends.

    Note: whitespace is collapsed *after* character stripping — the previous
    order left double spaces wherever a removed character was flanked by
    spaces (e.g. "foo - bar" -> "foo  bar").

    Args:
        text: raw text as produced by pytesseract (may be empty).

    Returns:
        str: cleaned, single-spaced text.
    """
    text = text.replace("\n", " ")
    text = re.sub(r"[^A-Za-z0-9.,$ ]", "", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()
| # ========================= | |
| # PDF → OCR | |
| # ========================= | |
def extract_text_from_pdf(pdf_bytes, dpi=300, progress_cb=None):
    """
    Run Tesseract OCR over every page of an in-memory PDF.

    Each page is rasterized to a PNG at the requested DPI, OCR'd with
    Tesseract (English), and the per-page texts are concatenated with
    "--- Page N ---" separators before being cleaned.

    Args:
        pdf_bytes: raw PDF file contents.
        dpi: rasterization resolution; higher improves OCR but is slower.
        progress_cb: optional callable (pages_done, total_pages) invoked
            after each page for UI progress reporting.

    Returns:
        str: cleaned OCR text for the whole document (see clean_ocr_text).
    """
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    try:
        total_pages = doc.page_count or 0
        page_texts = []
        for page_num, page in enumerate(doc):
            # Rasterize the page, then OCR the bitmap.
            pix = page.get_pixmap(dpi=int(dpi))
            img = Image.open(io.BytesIO(pix.tobytes("png")))
            text = pytesseract.image_to_string(img, lang="eng")
            page_texts.append(f"\n--- Page {page_num + 1} ---\n{text}")
            if progress_cb and total_pages > 0:
                progress_cb(page_num + 1, total_pages)
    finally:
        # Always release the document handle, even if OCR fails mid-way
        # (the original version leaked it).
        doc.close()
    return clean_ocr_text("".join(page_texts))
| # ========================= | |
| # PREDICTION | |
| # ========================= | |
def predict(text, top_k=3, max_length=MAX_LENGTH):
    """
    Classify `text` and return the top-k predictions.

    Args:
        text: document text to classify.
        top_k: number of (label, confidence) pairs to return.
        max_length: tokenizer truncation length.

    Returns:
        list[tuple[str, float]]: (label, confidence-percent) pairs,
        highest confidence first.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=int(max_length),
    )
    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        logits = model(**encoded).logits
    probabilities = F.softmax(logits, dim=1)
    scores, indices = torch.topk(probabilities, k=top_k, dim=1)
    results = []
    for idx, score in zip(indices[0], scores[0]):
        results.append((model.config.id2label[idx.item()], score.item() * 100))
    return results
| # ========================= | |
| # STREAMLIT UI | |
| # ========================= | |
def _inject_ui_css():
    """
    Inject the app-wide CSS theme via a raw <style> block.

    Styles the gradient canvas, glassmorphism cards, file-uploader dropzone,
    buttons, and the animated confidence bars rendered by
    _render_predictions. Must be called once before any themed markup.
    """
    st.markdown(
        """
        <style>
        /* ---- App canvas ---- */
        .stApp {
            background:
                radial-gradient(1200px 600px at 10% 10%, rgba(124,58,237,0.25), transparent 55%),
                radial-gradient(900px 500px at 90% 15%, rgba(34,197,94,0.18), transparent 60%),
                radial-gradient(900px 700px at 50% 90%, rgba(59,130,246,0.16), transparent 55%),
                linear-gradient(180deg, #070B18 0%, #0B1020 45%, #070B18 100%);
            color: #E5E7EB;
            font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, "Apple Color Emoji", "Segoe UI Emoji";
        }
        /* Make the content width feel 'product' */
        section.main > div.block-container { padding-top: 1.25rem; padding-bottom: 2.5rem; max-width: 1100px; }
        /* ---- Glass cards ---- */
        .glass {
            background: rgba(255,255,255,0.06);
            border: 1px solid rgba(255,255,255,0.10);
            box-shadow: 0 10px 30px rgba(0,0,0,0.35);
            border-radius: 18px;
            padding: 18px 18px;
            backdrop-filter: blur(10px);
        }
        .hero-title {
            font-size: 2.25rem;
            line-height: 1.15;
            font-weight: 800;
            letter-spacing: -0.02em;
            margin: 0 0 0.35rem 0;
        }
        .hero-sub {
            color: rgba(229,231,235,0.78);
            font-size: 1.02rem;
            margin: 0 0 0.85rem 0;
        }
        .badge {
            display: inline-flex;
            align-items: center;
            gap: 8px;
            padding: 6px 10px;
            border-radius: 999px;
            background: rgba(124,58,237,0.20);
            border: 1px solid rgba(124,58,237,0.35);
            color: rgba(243,244,246,0.95);
            font-size: 0.85rem;
            margin-bottom: 10px;
        }
        /* Subtle entrance animation */
        @keyframes fadeUp {
            0% { opacity: 0; transform: translateY(10px); }
            100% { opacity: 1; transform: translateY(0); }
        }
        .fade-up { animation: fadeUp 520ms ease-out both; }
        /* ---- File uploader dropzone ---- */
        [data-testid="stFileUploaderDropzone"] {
            border: 1px dashed rgba(255,255,255,0.25) !important;
            background: rgba(255,255,255,0.04) !important;
            border-radius: 18px !important;
            padding: 22px !important;
        }
        [data-testid="stFileUploaderDropzone"]:hover {
            border-color: rgba(124,58,237,0.7) !important;
            box-shadow: 0 0 0 4px rgba(124,58,237,0.18);
        }
        /* ---- Buttons ---- */
        .stButton > button {
            border: 1px solid rgba(255,255,255,0.14);
            background: linear-gradient(135deg, rgba(124,58,237,0.95), rgba(59,130,246,0.88));
            color: white;
            border-radius: 14px;
            padding: 0.7rem 1rem;
            font-weight: 700;
            transition: transform 120ms ease, filter 120ms ease, box-shadow 120ms ease;
            box-shadow: 0 10px 22px rgba(0,0,0,0.35);
        }
        .stButton > button:hover {
            transform: translateY(-1px);
            filter: brightness(1.05);
            box-shadow: 0 14px 26px rgba(0,0,0,0.45);
        }
        /* ---- Text areas & inputs ---- */
        textarea, input {
            border-radius: 14px !important;
        }
        /* ---- Animated results bars ---- */
        .result-row {
            display: grid;
            grid-template-columns: 140px 1fr 70px;
            align-items: center;
            gap: 12px;
            padding: 10px 12px;
            border-radius: 14px;
            background: rgba(255,255,255,0.04);
            border: 1px solid rgba(255,255,255,0.08);
            margin: 10px 0;
        }
        .result-label { font-weight: 700; color: rgba(243,244,246,0.95); }
        .bar-wrap {
            height: 12px;
            border-radius: 999px;
            background: rgba(255,255,255,0.08);
            overflow: hidden;
            border: 1px solid rgba(255,255,255,0.10);
        }
        .bar {
            height: 100%;
            width: 0%;
            border-radius: 999px;
            background: linear-gradient(90deg, rgba(34,197,94,0.95), rgba(124,58,237,0.95), rgba(59,130,246,0.95));
            animation: grow 900ms cubic-bezier(.2,.8,.2,1) forwards;
        }
        @keyframes grow { to { width: var(--w); } }
        .result-pct { font-variant-numeric: tabular-nums; color: rgba(229,231,235,0.82); text-align: right; }
        /* ---- Footer ---- */
        .footer {
            margin-top: 28px;
            color: rgba(229,231,235,0.55);
            font-size: 0.9rem;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )
def _hero():
    """
    Render the two-column hero section: a text card on the left and a
    decorative Lottie animation on the right.
    """
    # Streamlit 1.27.0 (pinned in Docker) does not support `vertical_alignment=...`.
    left, right = st.columns([1.35, 1])
    with left:
        st.markdown(
            """
            <div class="glass fade-up">
            <div class="badge">⚡ OCR + LoRA TinyBERT • Production UI</div>
            <div class="hero-title">Document Classification, done fast.</div>
            <div class="hero-sub">
            Upload a scanned PDF (or paste text), extract OCR, and get the top predictions with confidence —
            in a clean, modern dashboard.
            </div>
            </div>
            """,
            unsafe_allow_html=True,
        )
    with right:
        # Lottie animation (works best with internet; safely degrades if blocked).
        # Rendered in a components iframe because it needs its own <script> tag.
        components.html(
            """
            <div class="glass fade-up" style="padding: 8px 10px;">
            <script src="https://unpkg.com/@lottiefiles/lottie-player@latest/dist/lottie-player.js"></script>
            <lottie-player
                src="https://assets10.lottiefiles.com/packages/lf20_1LhsaB.json"
                background="transparent"
                speed="1"
                style="width: 100%; height: 260px;"
                loop
                autoplay>
            </lottie-player>
            </div>
            """,
            height=290,
        )
def _render_predictions(predictions):
    """
    Render (label, confidence-percent) pairs as animated bars inside a
    glass card. Labels are HTML-escaped; confidences are clamped to
    [0, 100] so the CSS width is always valid.
    """

    def _row(label, confidence):
        # IMPORTANT: no leading indentation in the emitted HTML — Markdown
        # treats indented HTML as a code block.
        pct = min(100.0, max(0.0, float(confidence)))
        return (
            f'<div class="result-row">'
            f'<div class="result-label">{_html.escape(str(label))}</div>'
            f'<div class="bar-wrap"><div class="bar" style="--w: {pct:.2f}%;"></div></div>'
            f'<div class="result-pct">{pct:.1f}%</div>'
            f"</div>"
        )

    body = "".join(_row(label, conf) for label, conf in predictions)
    st.markdown(f"<div class='glass fade-up'>{body}</div>", unsafe_allow_html=True)
# Theme must be injected before any themed markup below.
_inject_ui_css()
# Sidebar: user-tunable knobs read by the main panels below.
with st.sidebar:
    st.markdown("### Controls")
    # Number of labels to show in the results panel (forwarded to predict()).
    top_k = st.slider("Top-K predictions", min_value=1, max_value=5, value=3, help="How many labels to show.")
    # Rasterization DPI used by extract_text_from_pdf; also part of the OCR cache key.
    ocr_dpi = st.slider(
        "OCR quality (DPI)",
        min_value=150,
        max_value=400,
        value=250,
        step=25,
        help="Higher DPI improves OCR but increases processing time.",
    )
    # How many characters of OCR text to show in the preview expander.
    preview_chars = st.slider(
        "Preview length",
        min_value=500,
        max_value=8000,
        value=3000,
        step=500,
        help="How much OCR text to show in the preview box.",
    )
    st.markdown("---")
    st.caption("Tip: For low-quality scans, increase DPI and re-run OCR.")
_hero()
st.write("")
# Two workflows: OCR a scanned PDF, or classify already-digital text directly.
tab_upload, tab_paste = st.tabs(["Upload PDF", "Paste Text"])
with tab_upload:
    st.markdown("<div class='glass fade-up'>", unsafe_allow_html=True)
    st.markdown("#### Upload a scanned PDF")
    uploaded_file = st.file_uploader(
        "Upload PDF",
        type=["pdf"],
        label_visibility="collapsed",
        help="PDF should contain scanned pages or images. We'll OCR it, then classify.",
    )
    st.markdown("</div>", unsafe_allow_html=True)
    if uploaded_file:
        # IMPORTANT: Streamlit reruns the script on every interaction. If we OCR inside this block,
        # clicking "Classify" would re-trigger OCR and feel like it's stuck. So we store OCR output
        # in session_state keyed by (file hash + DPI).
        pdf_bytes = uploaded_file.getvalue()
        file_hash = hashlib.sha256(pdf_bytes).hexdigest()[:16]
        ocr_key = f"{file_hash}:{int(ocr_dpi)}"
        if st.session_state.get("ocr_key") != ocr_key:
            # New file (or changed DPI) — invalidate any cached OCR result.
            st.session_state["ocr_key"] = ocr_key
            st.session_state["ocr_text"] = None
            st.session_state["ocr_seconds"] = None
        extracted_text = st.session_state.get("ocr_text")
        col_run, col_hint = st.columns([1, 2.2])
        with col_run:
            run_ocr = st.button("Run OCR", use_container_width=True, key="run_ocr_btn")
        with col_hint:
            st.markdown(
                "<div class='glass fade-up' style='padding: 14px 16px;'>"
                "<b>Tip</b><br/>"
                "OCR is the slowest part. Run it once, then classify instantly. "
                "Lower DPI = faster OCR."
                "</div>",
                unsafe_allow_html=True,
            )
        if run_ocr or (extracted_text is None and st.session_state.get("auto_ocr_once") is None):
            # Auto-run OCR once on first upload to keep UX smooth, but never re-run on button clicks.
            st.session_state["auto_ocr_once"] = True
            # `text=` was added to st.progress in later Streamlit versions; keep compatible with 1.27.0.
            prog = st.progress(0)
            prog_text = st.empty()
            prog_text.caption("Running OCR…")

            def _cb(done, total):
                # Per-page progress callback driven by extract_text_from_pdf.
                pct = int((done / total) * 100)
                prog.progress(pct)
                prog_text.caption(f"Running OCR… {done}/{total} pages")

            t0 = time.time()
            with st.spinner("Extracting text with Tesseract…"):
                extracted_text = extract_text_from_pdf(pdf_bytes, dpi=ocr_dpi, progress_cb=_cb)
            st.session_state["ocr_text"] = extracted_text
            st.session_state["ocr_seconds"] = max(0.0, time.time() - t0)
            prog.empty()
            prog_text.empty()
        # Re-read from session_state so button-only reruns see the cached OCR text.
        extracted_text = st.session_state.get("ocr_text")
        if extracted_text:
            secs = st.session_state.get("ocr_seconds")
            if secs is not None:
                st.caption(f"OCR completed in {secs:.1f}s • DPI {int(ocr_dpi)}")
            with st.expander("Extracted text preview", expanded=False):
                st.text_area("OCR Output", extracted_text[:preview_chars], height=260, label_visibility="collapsed")
            col_a, col_b = st.columns([1, 2.2])
            with col_a:
                run = st.button("Classify document", use_container_width=True, key="classify_doc_btn")
            with col_b:
                st.markdown(
                    "<div class='glass fade-up' style='padding: 14px 16px;'>"
                    "<b>What happens next?</b><br/>"
                    "We tokenize the OCR text and run your LoRA-adapted TinyBERT classifier."
                    "</div>",
                    unsafe_allow_html=True,
                )
            if run:
                with st.spinner("Running model inference…"):
                    predictions = predict(extracted_text, top_k=top_k, max_length=MAX_LENGTH)
                st.markdown("### Results")
                _render_predictions(predictions)
        else:
            st.info("Click **Run OCR** to extract text, then you can classify the document.")
with tab_paste:
    st.markdown("<div class='glass fade-up'>", unsafe_allow_html=True)
    st.markdown("#### Paste text (skip OCR)")
    pasted = st.text_area(
        "Paste text",
        placeholder="Paste document text here…",
        height=220,
        label_visibility="collapsed",
    )
    col1, col2 = st.columns([1, 3])
    with col1:
        run_text = st.button("Classify text", use_container_width=True)
    with col2:
        st.caption("Useful for already-digital documents, emails, or copied text.")
    st.markdown("</div>", unsafe_allow_html=True)
    # Only run inference when the button is pressed AND there is non-blank input.
    if run_text and pasted.strip():
        with st.spinner("Running model inference…"):
            # Apply the same normalization used for OCR output before classifying.
            predictions = predict(clean_ocr_text(pasted), top_k=top_k, max_length=MAX_LENGTH)
        st.markdown("### Results")
        _render_predictions(predictions)