# Page-scrape header (not part of the program): raahinaez's picture
# Commit message: "Update app.py"
# Commit: 4aec666 (verified)
import os
import io
import re
import html as _html
import hashlib
import time
import streamlit as st
import streamlit.components.v1 as components
import torch
import pytesseract
import fitz # PyMuPDF
from PIL import Image
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel
import torch.nn.functional as F
# =========================
# STREAMLIT PAGE CONFIG (MUST BE FIRST STREAMLIT COMMAND)
# =========================
def _configure_page():
    """Apply the page-level Streamlit settings, tolerating repeat calls.

    Streamlit only allows ``st.set_page_config()`` once per page. This module
    can be run directly or be imported by another entry module that has
    already configured the page, so a failure here is deliberately ignored
    instead of being allowed to crash the app.
    """
    page_settings = {
        "page_title": "OCR Document Classifier",
        "page_icon": "📄",
        "layout": "wide",
    }
    try:
        st.set_page_config(**page_settings)
    except Exception:
        # Best effort by design: when another module already called
        # set_page_config, Streamlit raises StreamlitAPIException. Keep the
        # settings that are already in place and carry on.
        return


_configure_page()
# =========================
# TESSERACT CONFIG (WINDOWS)
# =========================
if os.name == "nt":
    # Point pytesseract at the Tesseract binary. The TESSERACT_CMD environment
    # variable wins; otherwise fall back to the default Windows install path.
    pytesseract.pytesseract.tesseract_cmd = os.environ.get(
        "TESSERACT_CMD",
        r"C:\Program Files\Tesseract-OCR\tesseract.exe",
    )
    # Only fill in TESSDATA_PREFIX when the environment has not set it already.
    if "TESSDATA_PREFIX" not in os.environ:
        os.environ["TESSDATA_PREFIX"] = r"C:\Program Files\Tesseract-OCR\tessdata"
# =========================
# MODEL PATHS & CONFIG
# =========================
# Base checkpoint id, location of the LoRA adapter, and tokenizer length cap.
BASE_MODEL = "prajjwal1/bert-tiny"
ADAPTER_PATH = "./lora_adapter"
MAX_LENGTH = 256

# Class names; a label's id is its position in this list.
LABELS = [
    "Employment Letter",
    "Lease / Agreement",
    "Bank Statements",
    "Paystub / Payslip",
    "Property Tax",
    "Investment",
    "Tax Documents",
    "Other Documents",
    "ID / License",
]

# Index <-> label lookup tables derived from LABELS.
id2label = dict(enumerate(LABELS))
label2id = {name: index for index, name in id2label.items()}
NUM_LABELS = len(LABELS)
# =========================
# LOAD MODEL
# =========================
@st.cache_resource
def load_model():
    """Load the tokenizer and the adapter-wrapped classifier.

    The tokenizer is read from ``ADAPTER_PATH``. The sequence-classification
    model is loaded from ``BASE_MODEL`` with the label maps attached, wrapped
    with the PEFT adapter stored at ``ADAPTER_PATH``, and switched to eval
    mode.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(ADAPTER_PATH)

    label_config = {
        "num_labels": NUM_LABELS,
        "id2label": id2label,
        "label2id": label2id,
    }
    backbone = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL, **label_config)

    adapted = PeftModel.from_pretrained(backbone, ADAPTER_PATH)
    adapted.eval()
    return tok, adapted


tokenizer, model = load_model()
# =========================
# OCR TEXT CLEANING
# =========================
# Anything that is not an ASCII letter, digit, ".", ",", "$" or whitespace.
_OCR_DISALLOWED_RE = re.compile(r"[^A-Za-z0-9.,$\s]")
_OCR_WHITESPACE_RE = re.compile(r"\s+")


def clean_ocr_text(text):
    """Normalise raw OCR (or pasted) text into a single clean line.

    Keeps ASCII letters, digits, ``.``, ``,`` and ``$``; drops every other
    non-whitespace character; collapses each run of whitespace (newlines and
    tabs included) to one space; and trims the ends.

    Args:
        text: Text to clean. ``None`` or an empty string yields ``""``.

    Returns:
        The cleaned text, never containing consecutive spaces.
    """
    if not text:
        return ""
    # Strip symbols first and collapse whitespace afterwards: doing it the
    # other way round leaves a double space wherever a symbol sat between two
    # spaces (e.g. "a - b" -> "a  b").
    kept = _OCR_DISALLOWED_RE.sub("", text)
    return _OCR_WHITESPACE_RE.sub(" ", kept).strip()
# =========================
# PDF → OCR
# =========================
def extract_text_from_pdf(pdf_bytes, dpi=300, progress_cb=None):
    """OCR every page of an in-memory PDF and return the cleaned text.

    Each page is rendered to a PNG at the requested resolution, run through
    Tesseract (English), and prefixed with a ``--- Page N ---`` marker before
    the combined text is passed through ``clean_ocr_text``.

    Args:
        pdf_bytes: Raw bytes of the PDF document.
        dpi: Rendering resolution for each page; coerced with ``int()``.
        progress_cb: Optional callable invoked as
            ``progress_cb(pages_done, total_pages)`` after each page.

    Returns:
        The cleaned, concatenated OCR text of all pages.
    """
    render_dpi = int(dpi)
    page_chunks = []
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    try:
        total_pages = doc.page_count or 0
        for page_number, page in enumerate(doc, start=1):
            pix = page.get_pixmap(dpi=render_dpi)
            # Close the decoded image as soon as Tesseract is done with it.
            with Image.open(io.BytesIO(pix.tobytes("png"))) as img:
                page_text = pytesseract.image_to_string(img, lang="eng")
            page_chunks.append(f"\n--- Page {page_number} ---\n{page_text}")
            if progress_cb and total_pages > 0:
                progress_cb(page_number, total_pages)
    finally:
        # Release the document even when rendering or OCR raises.
        doc.close()
    return clean_ocr_text("".join(page_chunks))
# =========================
# PREDICTION
# =========================
def predict(text, top_k=3, max_length=MAX_LENGTH):
    """Classify ``text`` and return the most likely labels with confidences.

    Args:
        text: Document text to classify.
        top_k: Number of labels to return. Clamped to the range
            ``[1, number of classes]`` so an out-of-range request cannot make
            ``torch.topk`` raise.
        max_length: Token limit passed to the tokenizer (longer inputs are
            truncated); coerced with ``int()``.

    Returns:
        A list of ``(label, confidence_percent)`` tuples, highest confidence
        first.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=int(max_length),
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = F.softmax(logits, dim=1)

    # torch.topk fails when k is larger than the class dimension or below 1.
    k = max(1, min(int(top_k), probs.shape[1]))
    top_probs, top_indices = torch.topk(probs, k=k, dim=1)

    # A single text is passed in, so only row 0 of the batch is read.
    labels = [model.config.id2label[index] for index in top_indices[0].tolist()]
    confidences = [prob * 100 for prob in top_probs[0].tolist()]
    return list(zip(labels, confidences))
# =========================
# STREAMLIT UI
# =========================
def _inject_ui_css():
    """Emit the app's custom stylesheet into the page as raw HTML."""
    # The CSS text below is unchanged; it is only held in a local name so the
    # st.markdown call at the bottom stays readable.
    stylesheet = """
<style>
/* ---- App canvas ---- */
.stApp {
background:
radial-gradient(1200px 600px at 10% 10%, rgba(124,58,237,0.25), transparent 55%),
radial-gradient(900px 500px at 90% 15%, rgba(34,197,94,0.18), transparent 60%),
radial-gradient(900px 700px at 50% 90%, rgba(59,130,246,0.16), transparent 55%),
linear-gradient(180deg, #070B18 0%, #0B1020 45%, #070B18 100%);
color: #E5E7EB;
font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, "Apple Color Emoji", "Segoe UI Emoji";
}
/* Make the content width feel 'product' */
section.main > div.block-container { padding-top: 1.25rem; padding-bottom: 2.5rem; max-width: 1100px; }
/* ---- Glass cards ---- */
.glass {
background: rgba(255,255,255,0.06);
border: 1px solid rgba(255,255,255,0.10);
box-shadow: 0 10px 30px rgba(0,0,0,0.35);
border-radius: 18px;
padding: 18px 18px;
backdrop-filter: blur(10px);
}
.hero-title {
font-size: 2.25rem;
line-height: 1.15;
font-weight: 800;
letter-spacing: -0.02em;
margin: 0 0 0.35rem 0;
}
.hero-sub {
color: rgba(229,231,235,0.78);
font-size: 1.02rem;
margin: 0 0 0.85rem 0;
}
.badge {
display: inline-flex;
align-items: center;
gap: 8px;
padding: 6px 10px;
border-radius: 999px;
background: rgba(124,58,237,0.20);
border: 1px solid rgba(124,58,237,0.35);
color: rgba(243,244,246,0.95);
font-size: 0.85rem;
margin-bottom: 10px;
}
/* Subtle entrance animation */
@keyframes fadeUp {
0% { opacity: 0; transform: translateY(10px); }
100% { opacity: 1; transform: translateY(0); }
}
.fade-up { animation: fadeUp 520ms ease-out both; }
/* ---- File uploader dropzone ---- */
[data-testid="stFileUploaderDropzone"] {
border: 1px dashed rgba(255,255,255,0.25) !important;
background: rgba(255,255,255,0.04) !important;
border-radius: 18px !important;
padding: 22px !important;
}
[data-testid="stFileUploaderDropzone"]:hover {
border-color: rgba(124,58,237,0.7) !important;
box-shadow: 0 0 0 4px rgba(124,58,237,0.18);
}
/* ---- Buttons ---- */
.stButton > button {
border: 1px solid rgba(255,255,255,0.14);
background: linear-gradient(135deg, rgba(124,58,237,0.95), rgba(59,130,246,0.88));
color: white;
border-radius: 14px;
padding: 0.7rem 1rem;
font-weight: 700;
transition: transform 120ms ease, filter 120ms ease, box-shadow 120ms ease;
box-shadow: 0 10px 22px rgba(0,0,0,0.35);
}
.stButton > button:hover {
transform: translateY(-1px);
filter: brightness(1.05);
box-shadow: 0 14px 26px rgba(0,0,0,0.45);
}
/* ---- Text areas & inputs ---- */
textarea, input {
border-radius: 14px !important;
}
/* ---- Animated results bars ---- */
.result-row {
display: grid;
grid-template-columns: 140px 1fr 70px;
align-items: center;
gap: 12px;
padding: 10px 12px;
border-radius: 14px;
background: rgba(255,255,255,0.04);
border: 1px solid rgba(255,255,255,0.08);
margin: 10px 0;
}
.result-label { font-weight: 700; color: rgba(243,244,246,0.95); }
.bar-wrap {
height: 12px;
border-radius: 999px;
background: rgba(255,255,255,0.08);
overflow: hidden;
border: 1px solid rgba(255,255,255,0.10);
}
.bar {
height: 100%;
width: 0%;
border-radius: 999px;
background: linear-gradient(90deg, rgba(34,197,94,0.95), rgba(124,58,237,0.95), rgba(59,130,246,0.95));
animation: grow 900ms cubic-bezier(.2,.8,.2,1) forwards;
}
@keyframes grow { to { width: var(--w); } }
.result-pct { font-variant-numeric: tabular-nums; color: rgba(229,231,235,0.82); text-align: right; }
/* ---- Footer ---- */
.footer {
margin-top: 28px;
color: rgba(229,231,235,0.55);
font-size: 0.9rem;
}
</style>
"""
    st.markdown(stylesheet, unsafe_allow_html=True)
def _hero():
    """Render the two-column header: intro card left, Lottie animation right."""
    intro_card = """
<div class="glass fade-up">
<div class="badge">⚡ OCR + LoRA TinyBERT • Production UI</div>
<div class="hero-title">Document Classification, done fast.</div>
<div class="hero-sub">
Upload a scanned PDF (or paste text), extract OCR, and get the top predictions with confidence —
in a clean, modern dashboard.
</div>
</div>
"""
    # Lottie animation (works best with internet; safely degrades if blocked)
    animation_card = """
<div class="glass fade-up" style="padding: 8px 10px;">
<script src="https://unpkg.com/@lottiefiles/lottie-player@latest/dist/lottie-player.js"></script>
<lottie-player
src="https://assets10.lottiefiles.com/packages/lf20_1LhsaB.json"
background="transparent"
speed="1"
style="width: 100%; height: 260px;"
loop
autoplay>
</lottie-player>
</div>
"""

    # Streamlit 1.27.0 (pinned in Docker) does not support `vertical_alignment=...`.
    left, right = st.columns([1.35, 1])
    left.markdown(intro_card, unsafe_allow_html=True)
    with right:
        components.html(animation_card, height=290)
def _render_predictions(predictions):
    """Draw one animated confidence bar per ``(label, confidence)`` pair.

    Confidences are clamped to the 0-100 range and labels are HTML-escaped
    before being placed into the markup. All rows are rendered inside a single
    glass card.
    """

    def _row_html(label, conf):
        # Built without leading indentation: Markdown treats indented HTML as
        # a code block.
        pct = max(0.0, min(100.0, float(conf)))
        safe_label = _html.escape(str(label))
        return (
            f'<div class="result-row">'
            f'<div class="result-label">{safe_label}</div>'
            f'<div class="bar-wrap"><div class="bar" style="--w: {pct:.2f}%;"></div></div>'
            f'<div class="result-pct">{pct:.1f}%</div>'
            f"</div>"
        )

    card_body = "".join(_row_html(label, conf) for label, conf in predictions)
    st.markdown(f"<div class='glass fade-up'>{card_body}</div>", unsafe_allow_html=True)
_inject_ui_css()

# ---- Sidebar controls (read by both tabs below) ----
st.sidebar.markdown("### Controls")
top_k = st.sidebar.slider(
    "Top-K predictions",
    min_value=1,
    max_value=5,
    value=3,
    help="How many labels to show.",
)
ocr_dpi = st.sidebar.slider(
    "OCR quality (DPI)",
    min_value=150,
    max_value=400,
    value=250,
    step=25,
    help="Higher DPI improves OCR but increases processing time.",
)
preview_chars = st.sidebar.slider(
    "Preview length",
    min_value=500,
    max_value=8000,
    value=3000,
    step=500,
    help="How much OCR text to show in the preview box.",
)
st.sidebar.markdown("---")
st.sidebar.caption("Tip: For low-quality scans, increase DPI and re-run OCR.")

# ---- Page header and the two input modes ----
_hero()
st.write("")
tab_upload, tab_paste = st.tabs(["Upload PDF", "Paste Text"])
with tab_upload:
    st.markdown("<div class='glass fade-up'>", unsafe_allow_html=True)
    st.markdown("#### Upload a scanned PDF")
    uploaded_file = st.file_uploader(
        "Upload PDF",
        type=["pdf"],
        label_visibility="collapsed",
        help="PDF should contain scanned pages or images. We'll OCR it, then classify.",
    )
    st.markdown("</div>", unsafe_allow_html=True)

    if uploaded_file:
        # IMPORTANT: Streamlit reruns the script on every interaction. If we OCR inside this block,
        # clicking "Classify" would re-trigger OCR and feel like it's stuck. So we store OCR output
        # in session_state keyed by (file hash + DPI).
        pdf_bytes = uploaded_file.getvalue()
        file_hash = hashlib.sha256(pdf_bytes).hexdigest()[:16]
        ocr_key = f"{file_hash}:{int(ocr_dpi)}"
        if st.session_state.get("ocr_key") != ocr_key:
            # New file or new DPI: drop the cached OCR result.
            st.session_state["ocr_key"] = ocr_key
            st.session_state["ocr_text"] = None
            st.session_state["ocr_seconds"] = None
        extracted_text = st.session_state.get("ocr_text")

        col_run, col_hint = st.columns([1, 2.2])
        with col_run:
            run_ocr = st.button("Run OCR", use_container_width=True, key="run_ocr_btn")
        with col_hint:
            st.markdown(
                "<div class='glass fade-up' style='padding: 14px 16px;'>"
                "<b>Tip</b><br/>"
                "OCR is the slowest part. Run it once, then classify instantly. "
                "Lower DPI = faster OCR."
                "</div>",
                unsafe_allow_html=True,
            )

        if run_ocr or (extracted_text is None and st.session_state.get("auto_ocr_once") is None):
            # Auto-run OCR once on first upload to keep UX smooth, but never re-run on button clicks.
            st.session_state["auto_ocr_once"] = True
            # `text=` was added to st.progress in later Streamlit versions; keep compatible with 1.27.0.
            prog = st.progress(0)
            prog_text = st.empty()
            prog_text.caption("Running OCR…")

            def _cb(done, total):
                # Clamp so an unexpected done > total can never exceed the 0-100 range.
                pct = min(100, int((done / total) * 100))
                prog.progress(pct)
                prog_text.caption(f"Running OCR… {done}/{total} pages")

            # perf_counter is monotonic, so the elapsed time cannot go negative.
            t0 = time.perf_counter()
            try:
                with st.spinner("Extracting text with Tesseract…"):
                    extracted_text = extract_text_from_pdf(pdf_bytes, dpi=ocr_dpi, progress_cb=_cb)
                st.session_state["ocr_text"] = extracted_text
                st.session_state["ocr_seconds"] = time.perf_counter() - t0
            finally:
                # Remove the progress widgets even when OCR raises, so a failed
                # run does not leave a stale progress bar on the page.
                prog.empty()
                prog_text.empty()

        extracted_text = st.session_state.get("ocr_text")
        if extracted_text:
            secs = st.session_state.get("ocr_seconds")
            if secs is not None:
                st.caption(f"OCR completed in {secs:.1f}s • DPI {int(ocr_dpi)}")
            with st.expander("Extracted text preview", expanded=False):
                st.text_area(
                    "OCR Output",
                    extracted_text[:preview_chars],
                    height=260,
                    label_visibility="collapsed",
                )

            col_a, col_b = st.columns([1, 2.2])
            with col_a:
                run = st.button("Classify document", use_container_width=True, key="classify_doc_btn")
            with col_b:
                st.markdown(
                    "<div class='glass fade-up' style='padding: 14px 16px;'>"
                    "<b>What happens next?</b><br/>"
                    "We tokenize the OCR text and run your LoRA-adapted TinyBERT classifier."
                    "</div>",
                    unsafe_allow_html=True,
                )
            if run:
                with st.spinner("Running model inference…"):
                    predictions = predict(extracted_text, top_k=top_k, max_length=MAX_LENGTH)
                st.markdown("### Results")
                _render_predictions(predictions)
        elif extracted_text is None:
            # OCR has not produced a result for this file/DPI yet.
            st.info("Click **Run OCR** to extract text, then you can classify the document.")
        else:
            # OCR ran but returned an empty string; telling the user to click
            # "Run OCR" again would be misleading.
            st.warning("OCR finished but found no readable text. Try a higher DPI and run OCR again.")
with tab_paste:
    st.markdown("<div class='glass fade-up'>", unsafe_allow_html=True)
    st.markdown("#### Paste text (skip OCR)")
    pasted = st.text_area(
        "Paste text",
        placeholder="Paste document text here…",
        height=220,
        label_visibility="collapsed",
    )
    col1, col2 = st.columns([1, 3])
    with col1:
        run_text = st.button("Classify text", use_container_width=True)
    with col2:
        st.caption("Useful for already-digital documents, emails, or copied text.")
    st.markdown("</div>", unsafe_allow_html=True)

    if run_text:
        # Clean first, then decide: input made only of whitespace or stripped
        # symbols would otherwise reach the model as an empty string.
        cleaned_text = clean_ocr_text(pasted or "")
        if cleaned_text:
            with st.spinner("Running model inference…"):
                predictions = predict(cleaned_text, top_k=top_k, max_length=MAX_LENGTH)
            st.markdown("### Results")
            _render_predictions(predictions)
        else:
            # Give feedback instead of silently ignoring the click.
            st.warning("Nothing to classify yet. Paste some document text first.")