# UniversalOcrApp — src/streamlit_app.py
# (The following header lines are artifacts of the Hugging Face Space file
# view — author "Geraldine", commit "Update src/streamlit_app.py", dd2a65f
# verified — converted to comments so the module parses.)
import spaces
import streamlit as st
from PIL import Image
import torch
from huggingface_hub import snapshot_download
from transformers import (
AutoProcessor,
AutoModelForImageTextToText,
AutoModelForCausalLM,
)
import json
import cv2
import numpy as np
import pandas as pd
from io import BytesIO, StringIO
import datetime
from enum import Enum
from typing import Optional, Tuple
# ----------------------------
# Optional dependency (recommended for dots.ocr)
# ----------------------------
try:
from qwen_vl_utils import process_vision_info # type: ignore
except Exception:
process_vision_info = None
# ========================================
# MODEL TYPES
# ========================================
class OCRModel(str, Enum):
GLM_OCR = "glm_ocr"
DOTS_OCR = "dots_ocr"
MODEL_UI = {
OCRModel.GLM_OCR: {"name": "GLM-OCR", "icon": "🟦", "hf_id": "zai-org/GLM-OCR"},
OCRModel.DOTS_OCR: {"name": "dots.ocr", "icon": "🟩", "hf_id": "rednote-hilab/dots.ocr"},
}
# ========================================
# DOCUMENT TYPES & TEMPLATES
# ========================================
class DocumentType(str, Enum):
"""Supported document types"""
GENERAL = "general"
FULL_JSON_SCHEMA = "full_json_schema"
SIMPLE_TITLE_JSON = "simple_title_json"
LOCALIZED_TITLE_JSON = "localized_title_json"
GROUNDED_TITLE_JSON = "grounded_title_json"
HANDWRITTEN = "handwritten"
DOCUMENT_TEMPLATES = {
DocumentType.GENERAL: {
"name": "General Text",
"description": "Extract all text from any document",
"prompt": "Extract all text from this image. Preserve the layout and structure. Output plain text.",
"icon": "📄"
},
DocumentType.FULL_JSON_SCHEMA: {
"name": "Full Json Schema",
"description": "Extract structured data from this cover page",
"prompt": """Analyze this thesis/dissertation cover image and extract ONLY visible information.
CRITICAL: Only extract information that is CLEARLY VISIBLE on the page.
DO NOT invent, guess, or hallucinate any data. If a field is not visible, use null.
Return ONLY valid JSON with this exact structure:
{
"title": "Main title of the thesis or dissertation as it appears on the title page",
"subtitle": "Subtitle or remainder of the title, usually following a colon; null if not present",
"author": "Full name of the author (student) who wrote the thesis or dissertation",
"degree_type": "Academic degree sought by the author (e.g. PhD, Doctorate, Master's degree, Master's thesis)",
"discipline": "Academic field or discipline of the thesis if explicitly stated; null if not present. Possible values: Mathématiques|Physics|Biology|others",
"granting_institution": "Institution where the thesis was submitted and the degree is granted (degree-granting institution)",
"doctoral_school": "Doctoral school or graduate program, if explicitly mentioned; null if not present",
"co_tutelle_institutions": "List of institutions involved in a joint supervision or co-tutelle agreement; empty list if none",
"partner_institutions": "List of partner institutions associated with the thesis but not granting the degree; empty list if none",
"defense_year": "Year the thesis or dissertation was defended, in YYYY format; null if not visible",
"defense_place": "City or place where the defense took place, if stated; null if not present",
"thesis_advisor": "Main thesis advisor or supervisor (director of thesis); full name; null if not present",
"co_advisors": "List of co-advisors or co-supervisors if explicitly mentioned; full names; empty list if none",
"jury_president": "President or chair of the thesis examination committee, if specified; null if not present",
"reviewers": "List of reviewers or rapporteurs of the thesis, if specified; full names; empty list if none",
"committee_members": "List of other thesis committee or jury members, excluding advisor and reviewers; full names; empty list if none",
"language": "Language in which the thesis is written, if explicitly stated; null if not present",
"confidence": "Confidence score between 0.0 and 1.0 indicating reliability of the extracted metadata"
}
IMPORTANT: Return null for any field where information is NOT clearly visible.
Return ONLY the JSON, no explanation.""",
"icon": "🆔"
},
DocumentType.SIMPLE_TITLE_JSON: {
"name": "Simple Title Json",
"description": "Extract title from this cover page",
"prompt": """Extract the document title from this cover page.
Output ONLY valid JSON:
{
"title": ""
}""",
"icon": "🧾"
},
DocumentType.LOCALIZED_TITLE_JSON: {
"name": "Localized Title Json",
"description": "Extract localized title from this cover page",
"prompt": """Extract the document title from the middle central block of this cover page.
Output ONLY valid JSON:
{
"title": ""
}""",
"icon": "🧾"
},
DocumentType.GROUNDED_TITLE_JSON: {
"name": "Grounded Title Json",
"description": "Extract localized title from this cover page",
"prompt": """Extract the document title usually located around (0.5015,0.442) from this cover page.
Output ONLY valid JSON:
{
"title": ""
}""",
"icon": "📋"
},
DocumentType.HANDWRITTEN: {
"name": "Handwritten Note",
"description": "Extract text from handwritten documents",
"prompt": "Extract all handwritten text from this image. Output plain text, preserving line breaks.",
"icon": "✍️"
}
}
# ========================================
# MODEL LOADING
# ========================================
@st.cache_resource
def load_glm_ocr():
"""Load GLM-OCR model (cached)"""
model_name = MODEL_UI[OCRModel.GLM_OCR]["hf_id"]
with st.spinner("🔄 Loading GLM-OCR... (first time may take 1–3 minutes)"):
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
model = AutoModelForImageTextToText.from_pretrained(
model_name,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
low_cpu_mem_usage=True,
trust_remote_code=True
)
if not torch.cuda.is_available():
model = model.to(device)
model.eval()
return processor, model, device
@st.cache_resource
def load_dots_ocr():
"""Load dots.ocr model (cached)"""
repo_id = MODEL_UI[OCRModel.DOTS_OCR]["hf_id"]
# dots.ocr recommends avoiding '.' in local directory names (workaround mentioned in their docs)
model_path = "./models/DotsOCR"
snapshot_download(
repo_id=repo_id,
local_dir=model_path,
local_dir_use_symlinks=False,
)
with st.spinner("🔄 Loading dots.ocr... (first time may take 1–3 minutes)"):
device = "cuda" if torch.cuda.is_available() else "cpu"
#dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
dtype = torch.float32
model = AutoModelForCausalLM.from_pretrained(
model_path,
# attn_implementation="flash_attention_2", # optional if available
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
trust_remote_code=True,
low_cpu_mem_usage=True,
)
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
if not torch.cuda.is_available():
model = model.to(device)
model.eval()
return processor, model, device
def get_loaded_model(selected: OCRModel):
"""Return (processor, model, device) for selected model, cached by Streamlit."""
if selected == OCRModel.GLM_OCR:
return load_glm_ocr()
return load_dots_ocr()
# ========================================
# IMAGE PREPROCESSING
# ========================================
def preprocess_image(
image: Image.Image,
enhance_contrast: bool = False,
denoise: bool = False,
sharpen: bool = False,
auto_rotate: bool = False,
prevent_cropping: bool = False
) -> Image.Image:
if prevent_cropping and not auto_rotate:
raise Exception("Auto-Rotate must be enabled when Prevent-Cropping is active")
img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
if denoise:
gray = cv2.fastNlMeansDenoising(gray, None, 10, 7, 21)
if enhance_contrast:
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
gray = clahe.apply(gray)
if sharpen:
kernel = np.array([[-1, -1, -1],
[-1, 9, -1],
[-1, -1, -1]])
gray = cv2.filter2D(gray, -1, kernel)
if auto_rotate:
blurred = cv2.GaussianBlur(gray, (5, 5), 0)
edges = cv2.Canny(blurred, 50, 150)
lines = cv2.HoughLinesP(edges, 1, np.pi / 180, 100, minLineLength=80, maxLineGap=10)
if lines is not None and len(lines) > 0:
angles = []
for line in lines:
x1, y1, x2, y2 = line[0]
angle = np.degrees(np.arctan2(y2 - y1, x2 - x1))
if -45 < angle < 45:
angles.append(angle)
if angles:
median_angle = float(np.median(angles))
if abs(median_angle) > 45:
median_angle -= 90 * np.sign(median_angle)
(h0, w0) = gray.shape[:2]
center = (w0 // 2, h0 // 2)
M = cv2.getRotationMatrix2D(center, median_angle, 1.0)
out_w, out_h = w0, h0
if prevent_cropping:
cos = np.abs(M[0, 0])
sin = np.abs(M[0, 1])
out_w = int((h0 * sin) + (w0 * cos))
out_h = int((h0 * cos) + (w0 * sin))
M[0, 2] += (out_w / 2) - center[0]
M[1, 2] += (out_h / 2) - center[1]
gray = cv2.warpAffine(
gray,
M,
(out_w, out_h),
flags=cv2.INTER_CUBIC,
borderMode=cv2.BORDER_REPLICATE
)
return Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
# ========================================
# OCR EXTRACTION
# ========================================
def _now_ms() -> int:
return int(datetime.datetime.now().timestamp() * 1000)
def extract_text_glm(
image: Image.Image,
prompt: str,
max_tokens: int,
processor,
model,
device: str
) -> Tuple[str, int]:
start_ms = _now_ms()
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt}
]
}
]
inputs = processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt"
)
inputs = {k: v.to(device) for k, v in inputs.items() if k != "token_type_ids"}
with torch.no_grad():
generated_ids = model.generate(
**inputs,
max_new_tokens=max_tokens,
do_sample=False,
temperature=0.0
)
output_text = processor.decode(
generated_ids[0][inputs["input_ids"].shape[1]:],
skip_special_tokens=True
)
del inputs, generated_ids
if torch.cuda.is_available():
torch.cuda.empty_cache()
return output_text, _now_ms() - start_ms
def extract_text_dots(
image: Image.Image,
prompt: str,
max_tokens: int,
processor,
model,
device: str
) -> Tuple[str, int]:
"""
dots.ocr transformers inference (matches their model-card approach):
- apply_chat_template(tokenize=False)
- process_vision_info(messages) to get image_inputs/video_inputs
- processor(text=[...], images=..., videos=..., return_tensors="pt")
- generate, then trim input tokens, then decode
"""
start_ms = _now_ms()
# dots.ocr examples use {"type":"image","image": <path>} but PIL works in practice with most processors
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt}
]
}
]
text = processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
if process_vision_info is not None:
image_inputs, video_inputs = process_vision_info(messages)
else:
# Fallback: no video, single image
image_inputs, video_inputs = [image], None
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
# Some processors return a BatchEncoding with .to(...)
inputs = inputs.to(device)
# some processors add keys that this model doesn't use
unused_keys = ["mm_token_type_ids", "token_type_ids"]
for k in unused_keys:
if k in inputs:
inputs.pop(k)
with torch.no_grad():
generated_ids = model.generate(
**inputs,
max_new_tokens=max_tokens,
do_sample=False,
temperature=0.0
)
# Trim prompt tokens
in_ids = inputs["input_ids"]
trimmed = []
for i in range(generated_ids.shape[0]):
trimmed.append(generated_ids[i][in_ids.shape[1]:])
output_text = processor.batch_decode(
trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False
)[0]
del inputs, generated_ids, trimmed
if torch.cuda.is_available():
torch.cuda.empty_cache()
return output_text, _now_ms() - start_ms
def extract_text(
selected_model: OCRModel,
image: Image.Image,
prompt: str,
max_tokens: int,
) -> Tuple[str, int]:
processor, model, device = get_loaded_model(selected_model)
if selected_model == OCRModel.GLM_OCR:
return extract_text_glm(image, prompt, max_tokens, processor, model, device)
return extract_text_dots(image, prompt, max_tokens, processor, model, device)
# ========================================
# STREAMLIT UI
# ========================================
st.set_page_config(
page_title="Universal OCR Scanner",
page_icon="🔍",
layout="wide",
initial_sidebar_state="expanded"
)
# Initialize session state
for k, v in {
"should_process": False,
"has_results": False,
"output_text": "",
"processing_time": 0,
"doc_type": DocumentType.GENERAL,
"selected_model": OCRModel.GLM_OCR,
"current_file": None,
}.items():
if k not in st.session_state:
st.session_state[k] = v
st.title("🔍 Universal OCR Scanner")
st.markdown("Extract text and structured data from **any document** - receipts, IDs, invoices, forms, and more!")
with st.sidebar:
st.header("🧠 Model")
selected_model = st.radio(
"Select OCR model:",
options=list(OCRModel),
format_func=lambda x: f"{MODEL_UI[x]['icon']} {MODEL_UI[x]['name']}",
index=list(OCRModel).index(st.session_state.selected_model),
)
st.session_state.selected_model = selected_model
st.header("📋 Document Type")
doc_type = st.radio(
"Select document type:",
options=list(DocumentType),
format_func=lambda x: f"{DOCUMENT_TEMPLATES[x]['icon']} {DOCUMENT_TEMPLATES[x]['name']}",
index=list(DocumentType).index(st.session_state.doc_type),
)
st.session_state.doc_type = doc_type
st.info(DOCUMENT_TEMPLATES[doc_type]['description'])
st.markdown("---")
st.header("⚙️ Image Enhancement")
with st.expander("🎨 Preprocessing Options", expanded=False):
enhance_contrast = st.checkbox("Enhance Contrast", value=False, help="Improve visibility of faded text")
denoise = st.checkbox("Reduce Noise", value=False, help="Remove image noise and artifacts")
sharpen = st.checkbox("Sharpen Text", value=False, help="Make text edges crisper")
auto_rotate = st.checkbox("Auto-Rotate", value=False, help="Automatically straighten tilted documents")
prevent_cropping = st.checkbox("Prevent-Cropping", value=False, help="Prevent cropping when rotate")
st.markdown("---")
with st.expander("🔧 Advanced Options", expanded=False):
show_preprocessed = st.checkbox("Show Preprocessed Image", value=False)
max_tokens = st.slider("Max Output Tokens", 512, 4096, 2048, 256, help="Increase for longer documents")
custom_prompt = st.checkbox("Use Custom Prompt", value=False)
st.markdown("---")
st.caption("💡 **Tips:**")
st.caption("• Use good lighting")
st.caption("• Avoid shadows")
st.caption("• Keep text horizontal")
st.caption("• Use high resolution images")
col1, col2 = st.columns([1, 1])
with col1:
st.subheader("📤 Upload Document")
upload_tab, camera_tab = st.tabs(["📁 Upload File", "📸 Take Photo"])
image: Optional[Image.Image] = None
with upload_tab:
uploaded_file = st.file_uploader(
"Choose an image...",
type=["jpg", "jpeg", "png", "webp"],
help="Supported formats: JPG, PNG, WEBP"
)
if uploaded_file is not None:
image = Image.open(uploaded_file).convert("RGB")
if st.session_state.current_file != uploaded_file.name:
st.session_state.current_file = uploaded_file.name
st.session_state.has_results = False
with camera_tab:
camera_picture = st.camera_input("Take a photo")
if camera_picture is not None:
image = Image.open(BytesIO(camera_picture.getvalue())).convert("RGB")
st.session_state.has_results = False
if image is not None:
st.image(image, caption="Original Image", width="content")
with col2:
st.subheader("📋 Extraction Settings")
if custom_prompt:
prompt = st.text_area(
"Custom Extraction Prompt:",
value=DOCUMENT_TEMPLATES[doc_type]['prompt'],
height=200,
help="Customize how the OCR extracts data",
key="custom_prompt_text"
)
else:
prompt = DOCUMENT_TEMPLATES[doc_type]['prompt']
st.code(prompt, language="text")
if image is not None:
if st.button("🚀 Extract Text", type="primary", width="content", key="extract_button"):
st.session_state.should_process = True
else:
st.info("👆 Upload or capture an image to begin")
# Processing
if image is not None and st.session_state.get('should_process', False):
st.session_state.should_process = False
with st.spinner("🔄 Processing document..."):
try:
if enhance_contrast or denoise or sharpen or auto_rotate or prevent_cropping:
preprocessed_image = preprocess_image(
image,
enhance_contrast=enhance_contrast,
denoise=denoise,
sharpen=sharpen,
auto_rotate=auto_rotate,
prevent_cropping=prevent_cropping
)
else:
preprocessed_image = image
if show_preprocessed and preprocessed_image != image:
st.subheader("🔧 Preprocessed Image")
col_a, col_b = st.columns(2)
with col_a:
st.image(image, caption="Original", width="content")
with col_b:
st.image(preprocessed_image, caption="Enhanced", width="content")
output_text, processing_time = extract_text(
selected_model=st.session_state.selected_model,
image=preprocessed_image,
prompt=prompt,
max_tokens=max_tokens
)
st.session_state.output_text = output_text
st.session_state.processing_time = processing_time
st.session_state.preprocessed_image = preprocessed_image
st.session_state.has_results = True
except Exception as e:
st.error(f"❌ Error during extraction: {str(e)}")
import traceback
with st.expander("Show Error Details"):
st.code(traceback.format_exc())
st.session_state.has_results = False
# Results
if st.session_state.get('has_results', False):
output_text = st.session_state.output_text
processing_time = st.session_state.processing_time
preprocessed_image = st.session_state.get('preprocessed_image', image)
st.success(f"✅ Extraction complete! ({processing_time}ms)")
is_json = False
parsed_data = None
if doc_type in [
DocumentType.FULL_JSON_SCHEMA,
DocumentType.SIMPLE_TITLE_JSON,
DocumentType.LOCALIZED_TITLE_JSON,
DocumentType.GROUNDED_TITLE_JSON
]:
try:
clean_text = output_text
if "```json" in clean_text:
clean_text = clean_text.split("```json")[1].split("```")[0].strip()
elif "```" in clean_text:
clean_text = clean_text.split("```")[1].split("```")[0].strip()
if len(clean_text) > 50000:
st.warning("⚠️ Detected unusually large JSON output. Truncating...")
clean_text = clean_text[:50000]
parsed_data = json.loads(clean_text)
def flatten_dict(d, max_depth=2, current_depth=0):
if current_depth >= max_depth:
return {}
if not isinstance(d, dict):
return d
flattened = {}
for key, value in d.items():
if isinstance(value, dict):
if current_depth < max_depth - 1:
flattened[key] = flatten_dict(value, max_depth, current_depth + 1)
elif isinstance(value, list):
flattened[key] = value
else:
flattened[key] = value
return flattened
parsed_data = flatten_dict(parsed_data, max_depth=2)
is_json = True
except json.JSONDecodeError:
is_json = False
except Exception as e:
st.warning(f"⚠️ JSON parsing issue: {str(e)}")
is_json = False
st.markdown("---")
st.subheader("📄 Extracted Data")
if is_json and parsed_data:
col_display, col_download = st.columns([2, 1])
with col_display:
for key, value in parsed_data.items():
if isinstance(value, dict):
st.markdown(f"**{key.replace('_', ' ').title()}:**")
for k, v in value.items():
st.text(f" {k}: {v}")
elif isinstance(value, list):
st.markdown(f"**{key.replace('_', ' ').title()}:**")
if value and isinstance(value[0], dict):
df = pd.DataFrame(value)
st.dataframe(df, width="content", hide_index=True)
else:
for item in value:
st.text(f" • {item}")
else:
st.markdown(f"**{key.replace('_', ' ').title()}:** {value}")
with col_download:
st.subheader("💾 Downloads")
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
json_str = json.dumps(parsed_data, ensure_ascii=False, indent=2)
st.download_button(
label="📄 JSON",
data=json_str,
file_name=f"{doc_type.value}_{timestamp}.json",
mime="application/json",
width="content"
)
try:
flat_data = {}
for k, v in parsed_data.items():
if isinstance(v, (dict, list)):
flat_data[k] = json.dumps(v, ensure_ascii=False)
else:
flat_data[k] = v
df = pd.DataFrame([flat_data])
csv_buffer = StringIO()
df.to_csv(csv_buffer, index=False, encoding='utf-8')
st.download_button(
label="📊 CSV",
data=csv_buffer.getvalue(),
file_name=f"{doc_type.value}_{timestamp}.csv",
mime="text/csv",
width="content"
)
except Exception:
pass
st.download_button(
label="📝 TXT",
data=output_text,
file_name=f"{doc_type.value}_{timestamp}.txt",
mime="text/plain",
width="content"
)
with st.expander("🔍 View Raw JSON"):
st.json(parsed_data)
else:
st.text_area(
"Extracted Text:",
value=output_text,
height=400,
label_visibility="collapsed"
)
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
st.download_button(
label="💾 Download as TXT",
data=output_text,
file_name=f"extracted_text_{timestamp}.txt",
mime="text/plain"
)
# Footer
st.markdown("---")
col_footer1, col_footer2, col_footer3 = st.columns(3)
with col_footer1:
m = st.session_state.selected_model
st.caption(f"⚡ Powered by {MODEL_UI[m]['name']}")
with col_footer2:
# device depends on selected model (they’ll both use same device typically)
_, _, device = get_loaded_model(st.session_state.selected_model)
st.caption(f"🖥️ Device: {device.upper()}")
with col_footer3:
st.caption("🌟 Universal Document Scanner")