Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,6 +2,8 @@ import os
|
|
| 2 |
import io
|
| 3 |
import re
|
| 4 |
import html as _html
|
|
|
|
|
|
|
| 5 |
import streamlit as st
|
| 6 |
import streamlit.components.v1 as components
|
| 7 |
import torch
|
|
@@ -332,7 +334,7 @@ with st.sidebar:
|
|
| 332 |
"OCR quality (DPI)",
|
| 333 |
min_value=150,
|
| 334 |
max_value=400,
|
| 335 |
-
value=
|
| 336 |
step=25,
|
| 337 |
help="Higher DPI improves OCR but increases processing time.",
|
| 338 |
)
|
|
@@ -364,42 +366,84 @@ with tab_upload:
|
|
| 364 |
st.markdown("</div>", unsafe_allow_html=True)
|
| 365 |
|
| 366 |
if uploaded_file:
|
| 367 |
-
|
| 368 |
-
#
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
col_a, col_b = st.columns([1, 2.2])
|
| 387 |
-
with col_a:
|
| 388 |
-
run = st.button("Classify document", use_container_width=True)
|
| 389 |
-
with col_b:
|
| 390 |
st.markdown(
|
| 391 |
"<div class='glass fade-up' style='padding: 14px 16px;'>"
|
| 392 |
-
"<b>
|
| 393 |
-
"
|
|
|
|
| 394 |
"</div>",
|
| 395 |
unsafe_allow_html=True,
|
| 396 |
)
|
| 397 |
|
| 398 |
-
if
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
|
| 404 |
with tab_paste:
|
| 405 |
st.markdown("<div class='glass fade-up'>", unsafe_allow_html=True)
|
|
|
|
| 2 |
import io
|
| 3 |
import re
|
| 4 |
import html as _html
|
| 5 |
+
import hashlib
|
| 6 |
+
import time
|
| 7 |
import streamlit as st
|
| 8 |
import streamlit.components.v1 as components
|
| 9 |
import torch
|
|
|
|
| 334 |
"OCR quality (DPI)",
|
| 335 |
min_value=150,
|
| 336 |
max_value=400,
|
| 337 |
+
value=250,
|
| 338 |
step=25,
|
| 339 |
help="Higher DPI improves OCR but increases processing time.",
|
| 340 |
)
|
|
|
|
| 366 |
st.markdown("</div>", unsafe_allow_html=True)
|
| 367 |
|
| 368 |
if uploaded_file:
    # IMPORTANT: Streamlit reruns the script on every interaction. If we OCR inside this block,
    # clicking "Classify" would re-trigger OCR and feel like it's stuck. So we store OCR output
    # in session_state keyed by (file hash + DPI).
    pdf_bytes = uploaded_file.getvalue()
    file_hash = hashlib.sha256(pdf_bytes).hexdigest()[:16]
    ocr_key = f"{file_hash}:{int(ocr_dpi)}"

    # A different file or a different DPI invalidates whatever OCR result is cached.
    if st.session_state.get("ocr_key") != ocr_key:
        st.session_state["ocr_key"] = ocr_key
        st.session_state["ocr_text"] = None
        st.session_state["ocr_seconds"] = None

    extracted_text = st.session_state.get("ocr_text")

    col_run, col_hint = st.columns([1, 2.2])
    with col_run:
        run_ocr = st.button("Run OCR", use_container_width=True, key="run_ocr_btn")
    with col_hint:
        st.markdown(
            "<div class='glass fade-up' style='padding: 14px 16px;'>"
            "<b>Tip</b><br/>"
            "OCR is the slowest part. Run it once, then classify instantly. "
            "Lower DPI = faster OCR."
            "</div>",
            unsafe_allow_html=True,
        )

    if run_ocr or (extracted_text is None and st.session_state.get("auto_ocr_once") is None):
        # Auto-run OCR once on first upload to keep UX smooth, but never re-run on button clicks.
        # NOTE(review): nothing in the visible code clears "auto_ocr_once", so the automatic run
        # happens only for the first upload of a session; later files (or DPI changes) need the
        # "Run OCR" button. Confirm this is intended.
        st.session_state["auto_ocr_once"] = True

        # `text=` was added to st.progress in later Streamlit versions; keep compatible with 1.27.0.
        prog = st.progress(0)
        prog_text = st.empty()
        prog_text.caption("Running OCR…")

        def _cb(done, total):
            """Update the progress bar and caption after each processed page.

            `done` is the number of pages finished so far and `total` the page count,
            as reported by the OCR helper through its `progress_cb` hook.
            """
            # Guard against total == 0 (nothing to process) and clamp the result:
            # st.progress rejects integer values outside 0-100.
            pct = int((done / total) * 100) if total > 0 else 0
            prog.progress(max(0, min(100, pct)))
            prog_text.caption(f"Running OCR… {done}/{total} pages")

        # perf_counter is monotonic, so the elapsed time can never come out negative.
        t0 = time.perf_counter()
        try:
            with st.spinner("Extracting text with Tesseract…"):
                extracted_text = extract_text_from_pdf(pdf_bytes, dpi=ocr_dpi, progress_cb=_cb)
            st.session_state["ocr_text"] = extracted_text
            st.session_state["ocr_seconds"] = time.perf_counter() - t0
        finally:
            # Always remove the progress widgets, even when OCR raises, so a failed run
            # does not leave a stale "Running OCR…" bar on the page.
            prog.empty()
            prog_text.empty()

    # Re-read from session_state so this rerun and later ones take the same path.
    extracted_text = st.session_state.get("ocr_text")
    if extracted_text:
        secs = st.session_state.get("ocr_seconds")
        if secs is not None:
            st.caption(f"OCR completed in {secs:.1f}s • DPI {int(ocr_dpi)}")

        with st.expander("Extracted text preview", expanded=False):
            st.text_area("OCR Output", extracted_text[:preview_chars], height=260, label_visibility="collapsed")

        col_a, col_b = st.columns([1, 2.2])
        with col_a:
            run = st.button("Classify document", use_container_width=True, key="classify_doc_btn")
        with col_b:
            st.markdown(
                "<div class='glass fade-up' style='padding: 14px 16px;'>"
                "<b>What happens next?</b><br/>"
                "We tokenize the OCR text and run your LoRA-adapted TinyBERT classifier."
                "</div>",
                unsafe_allow_html=True,
            )

        if run:
            # Classification only reads the cached OCR text; it never triggers OCR again.
            with st.spinner("Running model inference…"):
                predictions = predict(extracted_text, top_k=top_k, max_length=MAX_LENGTH)
            st.markdown("### Results")
            _render_predictions(predictions)
    else:
        # Reached when no OCR text is cached yet, or when OCR produced empty text.
        st.info("Click **Run OCR** to extract text, then you can classify the document.")
|
| 447 |
|
| 448 |
with tab_paste:
|
| 449 |
st.markdown("<div class='glass fade-up'>", unsafe_allow_html=True)
|