"""OCR Document Classifier — Streamlit app.

Pipeline: upload a scanned PDF (or paste text) -> Tesseract OCR per page ->
clean the text -> classify with a LoRA-adapted TinyBERT sequence classifier.

NOTE(review): this file was recovered from a whitespace-mangled copy; the
HTML/CSS string literals inside the UI helpers were lost in extraction and
have been reconstructed minimally around the surviving visible text.
Confirm the markup against the original design before shipping.
"""

import hashlib
import html as _html
import io
import os
import re
import time

import fitz  # PyMuPDF
import pytesseract
import streamlit as st
import streamlit.components.v1 as components
import torch
import torch.nn.functional as F
from PIL import Image
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# =========================
# STREAMLIT PAGE CONFIG (MUST BE FIRST STREAMLIT COMMAND)
# =========================


def _configure_page():
    """Configure Streamlit page settings idempotently.

    This file can be executed directly (``streamlit run streamlit_app.py``)
    OR imported by ``app.py`` (Docker entrypoint). Streamlit only allows
    calling ``st.set_page_config()`` once per page, so a second call raises
    StreamlitAPIException — we swallow it to prevent a crash.
    """
    try:
        st.set_page_config(
            page_title="OCR Document Classifier",
            page_icon="📄",
            layout="wide",
        )
    except Exception:
        # Another module (e.g. `app.py`) already called set_page_config.
        pass


_configure_page()

# =========================
# TESSERACT CONFIG (WINDOWS)
# =========================
if os.name == "nt":
    # Allow override via env var; keep a sensible Windows default.
    pytesseract.pytesseract.tesseract_cmd = os.getenv(
        "TESSERACT_CMD", r"C:\Program Files\Tesseract-OCR\tesseract.exe"
    )
    os.environ.setdefault(
        "TESSDATA_PREFIX", r"C:\Program Files\Tesseract-OCR\tessdata"
    )

# =========================
# MODEL PATHS & CONFIG
# =========================
BASE_MODEL = "prajjwal1/bert-tiny"
ADAPTER_PATH = "./lora_adapter"
MAX_LENGTH = 256

LABELS = [
    "Employment Letter",
    "Lease / Agreement",
    "Bank Statements",
    "Paystub / Payslip",
    "Property Tax",
    "Investment",
    "Tax Documents",
    "Other Documents",
    "ID / License",
]
label2id = {label: i for i, label in enumerate(LABELS)}
id2label = {i: label for label, i in label2id.items()}
NUM_LABELS = len(LABELS)

# =========================
# LOAD MODEL
# =========================


@st.cache_resource
def load_model():
    """Load tokenizer + LoRA-adapted TinyBERT classifier (cached per process).

    Returns:
        (tokenizer, model): the tokenizer from the adapter directory and the
        PEFT-wrapped base model in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH)
    base_model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL, num_labels=NUM_LABELS, id2label=id2label, label2id=label2id
    )
    model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
    model.eval()
    return tokenizer, model


tokenizer, model = load_model()

# =========================
# OCR TEXT CLEANING
# =========================


def clean_ocr_text(text):
    """Normalize raw OCR output.

    Flattens newlines, collapses whitespace runs, and strips every character
    outside ``A-Z a-z 0-9 . , $`` and space (matches the training-time
    cleaning, so inference sees the same distribution).
    """
    text = text.replace("\n", " ")
    text = re.sub(r"\s+", " ", text)
    text = re.sub(r"[^A-Za-z0-9.,$ ]", "", text)
    return text.strip()


# =========================
# PDF → OCR
# =========================


def extract_text_from_pdf(pdf_bytes, dpi=300, progress_cb=None):
    """Render each PDF page to an image and OCR it with Tesseract.

    Args:
        pdf_bytes: raw PDF file contents.
        dpi: rasterization resolution; higher is slower but more accurate.
        progress_cb: optional ``(done_pages, total_pages)`` callback.

    Returns:
        Cleaned, concatenated OCR text for all pages.
    """
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    try:  # FIX: the document was previously never closed (resource leak).
        extracted_text = ""
        total_pages = doc.page_count or 0
        for page_num, page in enumerate(doc):
            pix = page.get_pixmap(dpi=int(dpi))
            img = Image.open(io.BytesIO(pix.tobytes("png")))
            text = pytesseract.image_to_string(img, lang="eng")
            extracted_text += f"\n--- Page {page_num + 1} ---\n{text}"
            if progress_cb and total_pages > 0:
                progress_cb(page_num + 1, total_pages)
    finally:
        doc.close()
    return clean_ocr_text(extracted_text)


# =========================
# PREDICTION
# =========================


def predict(text, top_k=3, max_length=MAX_LENGTH):
    """Classify ``text`` and return the top-k ``(label, confidence%)`` pairs.

    Args:
        text: document text (already OCR-cleaned).
        top_k: number of labels to return; clamped to the label count so
            ``torch.topk`` cannot raise for oversized requests.
        max_length: tokenizer truncation length.

    Returns:
        List of ``(label, confidence_percent)`` tuples, highest first.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=int(max_length),
    )
    with torch.no_grad():
        outputs = model(**inputs)
    probs = F.softmax(outputs.logits, dim=1)
    # FIX: clamp k — torch.topk errors if k exceeds the number of classes.
    k = max(1, min(int(top_k), probs.size(-1)))
    top_probs, top_indices = torch.topk(probs, k=k, dim=1)
    top_labels = [model.config.id2label[idx.item()] for idx in top_indices[0]]
    top_confidences = [p.item() * 100 for p in top_probs[0]]
    return list(zip(top_labels, top_confidences))


# =========================
# STREAMLIT UI
# =========================


def _inject_ui_css():
    """Inject the app's custom CSS.

    NOTE(review): the original <style> payload was lost in extraction;
    restore the real stylesheet here. Left as a no-op markdown call so the
    app still runs.
    """
    st.markdown(
        """
        <style>
        /* NOTE(review): original CSS lost — reinsert app styles here. */
        </style>
        """,
        unsafe_allow_html=True,
    )


def _hero():
    """Render the hero header: headline text on the left, animation on the right.

    NOTE(review): surrounding HTML markup is a reconstruction — only the
    visible copy survived extraction. Streamlit 1.27.0 (pinned in Docker)
    does not support ``vertical_alignment=...`` on ``st.columns``.
    """
    left, right = st.columns([1.35, 1])
    with left:
        st.markdown(
            """
            <div>
            <div>⚡ OCR + LoRA TinyBERT • Production UI</div>
            <h2>Document Classification, done fast.</h2>
            <p>Upload a scanned PDF (or paste text), extract OCR, and get the
            top predictions with confidence — in a clean, modern dashboard.</p>
            </div>
            """,
            unsafe_allow_html=True,
        )
    with right:
        # Lottie animation (works best with internet; safely degrades if
        # blocked). NOTE(review): original embed HTML lost — placeholder.
        components.html(
            """
            <!-- NOTE(review): original Lottie embed lost — reinsert here. -->
            """,
            height=290,
        )


def _render_predictions(predictions):
    """Render ``(label, confidence%)`` pairs as labelled progress bars.

    Confidence is clamped to [0, 100] for the bar width; labels are
    HTML-escaped since they are interpolated into raw markup.
    """
    rows = []
    for label, conf in predictions:
        w = max(0.0, min(100.0, float(conf)))
        safe_label = _html.escape(str(label))
        # IMPORTANT: no leading indentation in the emitted HTML — Markdown
        # treats indented HTML as a code block.
        # NOTE(review): row markup is a reconstruction of the lost original.
        rows.append(
            f'<div class="pred-row">'
            f'<div class="pred-label">{safe_label}</div>'
            f'<div class="pred-bar"><div class="pred-fill" style="width:{w}%"></div></div>'
            f'<div class="pred-pct">{w:.1f}%</div>'
            f"</div>"
        )
    st.markdown(
        f'<div class="pred-list">{"".join(rows)}</div>',
        unsafe_allow_html=True,
    )


_inject_ui_css()

with st.sidebar:
    st.markdown("### Controls")
    top_k = st.slider(
        "Top-K predictions",
        min_value=1,
        max_value=5,
        value=3,
        help="How many labels to show.",
    )
    ocr_dpi = st.slider(
        "OCR quality (DPI)",
        min_value=150,
        max_value=400,
        value=250,
        step=25,
        help="Higher DPI improves OCR but increases processing time.",
    )
    preview_chars = st.slider(
        "Preview length",
        min_value=500,
        max_value=8000,
        value=3000,
        step=500,
        help="How much OCR text to show in the preview box.",
    )
    st.markdown("---")
    st.caption("Tip: For low-quality scans, increase DPI and re-run OCR.")

_hero()
st.write("")

tab_upload, tab_paste = st.tabs(["Upload PDF", "Paste Text"])

with tab_upload:
    st.markdown("#### Upload a scanned PDF")
    uploaded_file = st.file_uploader(
        "Upload PDF",
        type=["pdf"],
        label_visibility="collapsed",
        help="PDF should contain scanned pages or images. We'll OCR it, then classify.",
    )

    if uploaded_file:
        # IMPORTANT: Streamlit reruns the script on every interaction. If we
        # OCR inside this block unconditionally, clicking "Classify" would
        # re-trigger OCR and feel stuck. So OCR output is cached in
        # session_state keyed by (file hash + DPI).
        pdf_bytes = uploaded_file.getvalue()
        file_hash = hashlib.sha256(pdf_bytes).hexdigest()[:16]
        ocr_key = f"{file_hash}:{int(ocr_dpi)}"
        if st.session_state.get("ocr_key") != ocr_key:
            # New file or new DPI: invalidate the cached OCR result.
            st.session_state["ocr_key"] = ocr_key
            st.session_state["ocr_text"] = None
            st.session_state["ocr_seconds"] = None

        extracted_text = st.session_state.get("ocr_text")

        col_run, col_hint = st.columns([1, 2.2])
        with col_run:
            run_ocr = st.button(
                "Run OCR", use_container_width=True, key="run_ocr_btn"
            )
        with col_hint:
            # NOTE(review): box markup reconstructed — copy preserved.
            st.markdown(
                "<div><b>Tip</b><br>"
                "OCR is the slowest part. Run it once, then classify instantly. "
                "Lower DPI = faster OCR."
                "</div>",
                unsafe_allow_html=True,
            )

        if run_ocr or (
            extracted_text is None
            and st.session_state.get("auto_ocr_once") is None
        ):
            # Auto-run OCR once on first upload to keep UX smooth, but never
            # re-run on unrelated button clicks.
            st.session_state["auto_ocr_once"] = True
            # `text=` was added to st.progress in later Streamlit versions;
            # keep compatible with 1.27.0 by pairing it with st.empty().
            prog = st.progress(0)
            prog_text = st.empty()
            prog_text.caption("Running OCR…")

            def _cb(done, total):
                # Per-page progress callback fed by extract_text_from_pdf.
                pct = int((done / total) * 100)
                prog.progress(pct)
                prog_text.caption(f"Running OCR… {done}/{total} pages")

            t0 = time.time()
            with st.spinner("Extracting text with Tesseract…"):
                extracted_text = extract_text_from_pdf(
                    pdf_bytes, dpi=ocr_dpi, progress_cb=_cb
                )
            st.session_state["ocr_text"] = extracted_text
            st.session_state["ocr_seconds"] = max(0.0, time.time() - t0)
            prog.empty()
            prog_text.empty()

        extracted_text = st.session_state.get("ocr_text")
        if extracted_text:
            secs = st.session_state.get("ocr_seconds")
            if secs is not None:
                st.caption(f"OCR completed in {secs:.1f}s • DPI {int(ocr_dpi)}")
            with st.expander("Extracted text preview", expanded=False):
                st.text_area(
                    "OCR Output",
                    extracted_text[:preview_chars],
                    height=260,
                    label_visibility="collapsed",
                )

            col_a, col_b = st.columns([1, 2.2])
            with col_a:
                run = st.button(
                    "Classify document",
                    use_container_width=True,
                    key="classify_doc_btn",
                )
            with col_b:
                # NOTE(review): box markup reconstructed — copy preserved.
                st.markdown(
                    "<div><b>What happens next?</b><br>"
                    "We tokenize the OCR text and run your LoRA-adapted "
                    "TinyBERT classifier."
                    "</div>",
                    unsafe_allow_html=True,
                )

            if run:
                with st.spinner("Running model inference…"):
                    predictions = predict(
                        extracted_text, top_k=top_k, max_length=MAX_LENGTH
                    )
                st.markdown("### Results")
                _render_predictions(predictions)
        else:
            st.info(
                "Click **Run OCR** to extract text, then you can classify the document."
            )

with tab_paste:
    st.markdown("#### Paste text (skip OCR)")
    pasted = st.text_area(
        "Paste text",
        placeholder="Paste document text here…",
        height=220,
        label_visibility="collapsed",
    )
    col1, col2 = st.columns([1, 3])
    with col1:
        run_text = st.button("Classify text", use_container_width=True)
    with col2:
        st.caption("Useful for already-digital documents, emails, or copied text.")

    if run_text and pasted.strip():
        with st.spinner("Running model inference…"):
            predictions = predict(
                clean_ocr_text(pasted), top_k=top_k, max_length=MAX_LENGTH
            )
        st.markdown("### Results")
        _render_predictions(predictions)