import os
import io
import re
import html as _html
import hashlib
import time
import streamlit as st
import streamlit.components.v1 as components
import torch
import pytesseract
import fitz # PyMuPDF
from PIL import Image
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel
import torch.nn.functional as F
# =========================
# STREAMLIT PAGE CONFIG (MUST BE FIRST STREAMLIT COMMAND)
# =========================
def _configure_page():
    """
    Apply the Streamlit page settings, tolerating repeat calls.

    This file can be executed directly (`streamlit run streamlit_app.py`) OR
    imported by `app.py` (Docker entrypoint). Streamlit permits only one
    `st.set_page_config()` call per page, so the call is made idempotent by
    swallowing the exception raised on a second attempt.
    """
    try:
        st.set_page_config(
            page_title="OCR Document Classifier",
            page_icon="📄",
            layout="wide",
        )
    except Exception:
        # Another module (e.g. `app.py`) already configured the page;
        # Streamlit raises StreamlitAPIException — ignore to avoid a crash.
        pass


_configure_page()
# =========================
# TESSERACT CONFIG (WINDOWS)
# =========================
if os.name == "nt":
    # Windows only: point pytesseract at the Tesseract binary. The location
    # can be overridden via the TESSERACT_CMD environment variable; fall back
    # to the standard install path.
    _default_tesseract = r"C:\Program Files\Tesseract-OCR\tesseract.exe"
    pytesseract.pytesseract.tesseract_cmd = os.getenv("TESSERACT_CMD", _default_tesseract)
    # Only set TESSDATA_PREFIX when the user has not already configured it.
    os.environ.setdefault("TESSDATA_PREFIX", r"C:\Program Files\Tesseract-OCR\tessdata")
# =========================
# MODEL PATHS & CONFIG
# =========================
BASE_MODEL = "prajjwal1/bert-tiny"
ADAPTER_PATH = "./lora_adapter"
MAX_LENGTH = 256

# Class names in training order: list index == model output id.
LABELS = [
    "Employment Letter",
    "Lease / Agreement",
    "Bank Statements",
    "Paystub / Payslip",
    "Property Tax",
    "Investment",
    "Tax Documents",
    "Other Documents",
    "ID / License",
]

# Bidirectional id <-> label maps expected by the HF config.
id2label = dict(enumerate(LABELS))
label2id = {name: idx for idx, name in id2label.items()}
NUM_LABELS = len(LABELS)
# =========================
# LOAD MODEL
# =========================
@st.cache_resource
def load_model():
    """
    Load the tokenizer and the LoRA-adapted TinyBERT classifier.

    Cached by Streamlit (`st.cache_resource`) so the heavy load happens once
    per server process, not on every script rerun.

    Returns:
        (tokenizer, model): tokenizer from the adapter directory and a
        PEFT-wrapped sequence-classification model in eval mode.
    """
    tok = AutoTokenizer.from_pretrained(ADAPTER_PATH)
    backbone = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL,
        num_labels=NUM_LABELS,
        label2id=label2id,
        id2label=id2label,
    )
    classifier = PeftModel.from_pretrained(backbone, ADAPTER_PATH)
    classifier.eval()  # inference mode: disables dropout
    return tok, classifier


tokenizer, model = load_model()
# =========================
# OCR TEXT CLEANING
# =========================
def clean_ocr_text(text):
    """
    Normalize raw OCR output for classification.

    Flattens newlines, collapses whitespace runs to single spaces, drops
    every character outside letters/digits/`.,$`/space, and strips the ends.
    """
    flattened = re.sub(r"\s+", " ", text.replace("\n", " "))
    return re.sub(r"[^A-Za-z0-9.,$ ]", "", flattened).strip()
# =========================
# PDF → OCR
# =========================
def extract_text_from_pdf(pdf_bytes, dpi=300, progress_cb=None):
    """
    OCR every page of a PDF and return the cleaned, concatenated text.

    Args:
        pdf_bytes: Raw PDF file contents.
        dpi: Rasterization resolution; higher is slower but more accurate.
        progress_cb: Optional ``progress_cb(done, total)`` callable invoked
            after each page for UI progress reporting.

    Returns:
        str: OCR text of all pages, normalized via ``clean_ocr_text``.
    """
    pages = []
    # Use the document as a context manager so the handle is always closed
    # (the previous version leaked it). Build a list and join once instead
    # of quadratic string concatenation.
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        total_pages = doc.page_count or 0
        for page_num, page in enumerate(doc):
            pix = page.get_pixmap(dpi=int(dpi))
            img = Image.open(io.BytesIO(pix.tobytes("png")))
            text = pytesseract.image_to_string(img, lang="eng")
            pages.append(f"\n--- Page {page_num + 1} ---\n{text}")
            if progress_cb and total_pages > 0:
                progress_cb(page_num + 1, total_pages)
    return clean_ocr_text("".join(pages))
# =========================
# PREDICTION
# =========================
def predict(text, top_k=3, max_length=MAX_LENGTH):
    """
    Classify ``text`` and return the top-k predictions.

    Args:
        text: Document text (already OCR-cleaned).
        top_k: Number of labels to return; clamped to the number of classes
            because ``torch.topk`` raises when k exceeds the dimension size.
        max_length: Tokenizer truncation length.

    Returns:
        list[tuple[str, float]]: (label, confidence-percent) pairs, best first.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=int(max_length),
    )
    with torch.no_grad():
        outputs = model(**inputs)
    probs = F.softmax(outputs.logits, dim=1)
    # Guard: never ask topk for more entries than there are classes.
    k = max(1, min(int(top_k), probs.size(-1)))
    top_probs, top_indices = torch.topk(probs, k=k, dim=1)
    top_labels = [model.config.id2label[idx.item()] for idx in top_indices[0]]
    top_confidences = [p.item() * 100 for p in top_probs[0]]
    return list(zip(top_labels, top_confidences))
# =========================
# STREAMLIT UI
# =========================
def _inject_ui_css():
    """Inject app-wide custom CSS into the page.

    NOTE(review): the stylesheet body appears to have been stripped from this
    copy of the file — only the injection scaffold (an empty markdown string
    with ``unsafe_allow_html=True``) remains. Restore the CSS from version
    control before shipping.
    """
    st.markdown(
        """
""",
        unsafe_allow_html=True,
    )
def _hero():
    """Render the hero header: pitch copy on the left, animation on the right.

    NOTE(review): the HTML markup inside these string literals looks partially
    stripped in this copy (no tags survive); the visible text content is
    preserved verbatim. Confirm against version control.
    """
    # Streamlit 1.27.0 (pinned in Docker) does not support `vertical_alignment=...`.
    left, right = st.columns([1.35, 1])
    with left:
        st.markdown(
            """
⚡ OCR + LoRA TinyBERT • Production UI
Document Classification, done fast.
Upload a scanned PDF (or paste text), extract OCR, and get the top predictions with confidence —
in a clean, modern dashboard.
""",
            unsafe_allow_html=True,
        )
    with right:
        # Lottie animation (works best with internet; safely degrades if blocked)
        components.html(
            """
""",
            height=290,
        )
def _render_predictions(predictions):
    """
    Render (label, confidence) pairs as HTML confidence bars.

    Args:
        predictions: list of ``(label, confidence_percent)`` tuples as
            produced by ``predict``.

    NOTE(review): the original inline HTML was corrupted in this copy of the
    file (unterminated string literals). The structure below — escaped label,
    width-scaled bar, percentage readout — reconstructs its evident intent;
    class names should be reconciled with the CSS in `_inject_ui_css`.
    """
    rows = []
    for label, conf in predictions:
        # Clamp to [0, 100] so a malformed confidence cannot break the bar.
        w = max(0.0, min(100.0, float(conf)))
        # Escape the label: OCR/model-derived text must never inject HTML.
        safe_label = _html.escape(str(label))
        # IMPORTANT: no leading indentation in the HTML — Markdown treats
        # indented HTML as a code block.
        rows.append(
            f'<div class="pred-row">'
            f'<span class="pred-label">{safe_label}</span>'
            f'<div class="pred-bar"><div class="pred-fill" style="width:{w:.1f}%"></div></div>'
            f'<span class="pred-pct">{w:.1f}%</span>'
            f'</div>'
        )
    st.markdown(f"<div class='pred-list'>{''.join(rows)}</div>", unsafe_allow_html=True)
# ---- Page scaffolding: CSS, sidebar controls, hero, tab layout --------------
_inject_ui_css()

with st.sidebar:
    st.markdown("### Controls")
    # Number of labels to surface in the results view.
    top_k = st.slider(
        "Top-K predictions",
        min_value=1,
        max_value=5,
        value=3,
        help="How many labels to show.",
    )
    # Rasterization resolution for the OCR pipeline.
    ocr_dpi = st.slider(
        "OCR quality (DPI)",
        min_value=150,
        max_value=400,
        value=250,
        step=25,
        help="Higher DPI improves OCR but increases processing time.",
    )
    # Cap on how much OCR text the preview expander shows.
    preview_chars = st.slider(
        "Preview length",
        min_value=500,
        max_value=8000,
        value=3000,
        step=500,
        help="How much OCR text to show in the preview box.",
    )
    st.markdown("---")
    st.caption("Tip: For low-quality scans, increase DPI and re-run OCR.")

_hero()
st.write("")

tab_upload, tab_paste = st.tabs(["Upload PDF", "Paste Text"])
with tab_upload:
    # NOTE(review): several inline-HTML string literals in this section were
    # corrupted in this copy of the file; visible text content is preserved
    # and minimal markup reconstructed. Reconcile with version control.
    st.markdown("", unsafe_allow_html=True)
    st.markdown("#### Upload a scanned PDF")
    uploaded_file = st.file_uploader(
        "Upload PDF",
        type=["pdf"],
        label_visibility="collapsed",
        help="PDF should contain scanned pages or images. We'll OCR it, then classify.",
    )
    st.markdown("", unsafe_allow_html=True)

    if uploaded_file:
        # IMPORTANT: Streamlit reruns the script on every interaction. If we
        # OCR'd inline here, clicking "Classify" would re-trigger OCR and feel
        # stuck. So OCR output is cached in session_state keyed by
        # (file hash + DPI).
        pdf_bytes = uploaded_file.getvalue()
        file_hash = hashlib.sha256(pdf_bytes).hexdigest()[:16]
        ocr_key = f"{file_hash}:{int(ocr_dpi)}"
        if st.session_state.get("ocr_key") != ocr_key:
            # New file or changed DPI — invalidate the cached OCR result.
            st.session_state["ocr_key"] = ocr_key
            st.session_state["ocr_text"] = None
            st.session_state["ocr_seconds"] = None

        extracted_text = st.session_state.get("ocr_text")

        col_run, col_hint = st.columns([1, 2.2])
        with col_run:
            run_ocr = st.button("Run OCR", use_container_width=True, key="run_ocr_btn")
        with col_hint:
            st.markdown(
                "<div class='hint'><b>Tip</b><br>"
                "OCR is the slowest part. Run it once, then classify instantly. "
                "Lower DPI = faster OCR."
                "</div>",
                unsafe_allow_html=True,
            )

        if run_ocr or (extracted_text is None and st.session_state.get("auto_ocr_once") is None):
            # Auto-run OCR once on first upload to keep UX smooth, but never
            # re-run on later button clicks.
            st.session_state["auto_ocr_once"] = True
            # `text=` was added to st.progress in later Streamlit versions;
            # keep compatible with the pinned 1.27.0.
            prog = st.progress(0)
            prog_text = st.empty()
            prog_text.caption("Running OCR…")

            def _cb(done, total):
                # Per-page progress callback fed to extract_text_from_pdf.
                prog.progress(int((done / total) * 100))
                prog_text.caption(f"Running OCR… {done}/{total} pages")

            t0 = time.time()
            with st.spinner("Extracting text with Tesseract…"):
                extracted_text = extract_text_from_pdf(pdf_bytes, dpi=ocr_dpi, progress_cb=_cb)
            st.session_state["ocr_text"] = extracted_text
            st.session_state["ocr_seconds"] = max(0.0, time.time() - t0)
            prog.empty()
            prog_text.empty()

        extracted_text = st.session_state.get("ocr_text")
        if extracted_text:
            secs = st.session_state.get("ocr_seconds")
            if secs is not None:
                st.caption(f"OCR completed in {secs:.1f}s • DPI {int(ocr_dpi)}")
            with st.expander("Extracted text preview", expanded=False):
                st.text_area(
                    "OCR Output",
                    extracted_text[:preview_chars],
                    height=260,
                    label_visibility="collapsed",
                )

            col_a, col_b = st.columns([1, 2.2])
            with col_a:
                run = st.button("Classify document", use_container_width=True, key="classify_doc_btn")
            with col_b:
                st.markdown(
                    "<div class='hint'><b>What happens next?</b><br>"
                    "We tokenize the OCR text and run your LoRA-adapted TinyBERT classifier."
                    "</div>",
                    unsafe_allow_html=True,
                )
            if run:
                with st.spinner("Running model inference…"):
                    predictions = predict(extracted_text, top_k=top_k, max_length=MAX_LENGTH)
                st.markdown("### Results")
                _render_predictions(predictions)
        else:
            st.info("Click **Run OCR** to extract text, then you can classify the document.")
with tab_paste:
    # NOTE(review): two markdown string literals here were corrupted in this
    # copy of the file; reconstructed as empty spacers. Reconcile with
    # version control.
    st.markdown("", unsafe_allow_html=True)
    st.markdown("#### Paste text (skip OCR)")
    pasted = st.text_area(
        "Paste text",
        placeholder="Paste document text here…",
        height=220,
        label_visibility="collapsed",
    )
    col1, col2 = st.columns([1, 3])
    with col1:
        run_text = st.button("Classify text", use_container_width=True)
    with col2:
        st.caption("Useful for already-digital documents, emails, or copied text.")
    st.markdown("", unsafe_allow_html=True)

    # Only classify when the button was clicked AND the box is non-blank.
    if run_text and pasted.strip():
        with st.spinner("Running model inference…"):
            predictions = predict(clean_ocr_text(pasted), top_k=top_k, max_length=MAX_LENGTH)
        st.markdown("### Results")
        _render_predictions(predictions)