# UniversalOcrApp — src/streamlit_app.py
# Universal OCR Scanner: a Streamlit app that extracts plain text and
# structured data from documents using the GLM-OCR vision-language model.
import streamlit as st
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
import json
import cv2
import numpy as np
import pandas as pd
from io import BytesIO, StringIO
import datetime
from enum import Enum
from typing import Dict, Any, Optional
import fitz
import io
# ========================================
# DOCUMENT TYPES & TEMPLATES
# ========================================
class DocumentType(str, Enum):
    """Supported document types.

    str-mixin Enum: members compare equal to their string values and
    serialize cleanly (e.g. `doc_type.value` is used in download filenames).
    """
    GENERAL = "general"                # free-form text extraction
    ID_CARD = "id_card"                # identity documents / passports
    RECEIPT = "receipt"                # retail receipts
    INVOICE = "invoice"                # invoices with line items
    BUSINESS_CARD = "business_card"    # contact cards
    FORM = "form"                      # filled forms
    HANDWRITTEN = "handwritten"        # handwritten notes
# Per-type UI metadata and extraction prompt.
# Each entry provides: "name" (widget label), "description" (sidebar help),
# "prompt" (instruction sent to the model alongside the image), and
# "icon" (emoji shown in the document-type radio).
DOCUMENT_TEMPLATES = {
    DocumentType.GENERAL: {
        "name": "General Text",
        "description": "Extract all text from any document",
        "prompt": "Extract all text from this image. Preserve the layout and structure. Output plain text.",
        "icon": "πŸ“„"
    },
    DocumentType.ID_CARD: {
        "name": "ID Card / Passport",
        "description": "Extract structured data from identity documents",
        # The flat-JSON instruction guards against the model emitting
        # recursively nested objects (handled again downstream by flatten_dict).
        "prompt": """Extract structured data from this identity document.
Output ONLY valid JSON with these exact fields, no nested objects:
{
"document_type": "",
"full_name": "",
"sex": "",
"date_of_birth": "",
"date_of_expiry": "",
"nationality": "",
"document_number": "",
"place_of_birth": "",
"personal_number": ""
}
IMPORTANT: Do NOT create nested or recursive structures. Keep it flat and simple.""",
        "icon": "πŸ†”"
    },
    DocumentType.RECEIPT: {
        "name": "Receipt",
        "description": "Extract items, prices, and totals from receipts",
        "prompt": """Extract information from this receipt.
Output ONLY valid JSON:
{
"merchant_name": "",
"date": "",
"time": "",
"items": [
{"name": "", "quantity": 1, "price": 0.0}
],
"subtotal": 0.0,
"tax": 0.0,
"total": 0.0,
"payment_method": ""
}""",
        "icon": "🧾"
    },
    DocumentType.INVOICE: {
        "name": "Invoice",
        "description": "Extract invoice details and line items",
        "prompt": """Extract information from this invoice.
Output ONLY valid JSON:
{
"invoice_number": "",
"date": "",
"due_date": "",
"vendor": {
"name": "",
"address": "",
"contact": ""
},
"customer": {
"name": "",
"address": "",
"contact": ""
},
"line_items": [
{"description": "", "quantity": 1, "unit_price": 0.0, "amount": 0.0}
],
"subtotal": 0.0,
"tax": 0.0,
"total": 0.0
}""",
        "icon": "πŸ“‹"
    },
    DocumentType.BUSINESS_CARD: {
        "name": "Business Card",
        "description": "Extract contact information",
        "prompt": """Extract contact information from this business card.
Output ONLY valid JSON:
{
"name": "",
"title": "",
"company": "",
"email": "",
"phone": "",
"mobile": "",
"website": "",
"address": "",
"social_media": {}
}""",
        "icon": "πŸ’Ό"
    },
    DocumentType.FORM: {
        "name": "Form",
        "description": "Extract filled form data",
        "prompt": """Extract all fields and values from this form.
Output ONLY valid JSON with field names as keys and filled values:
{
"field_name": "value"
}""",
        "icon": "πŸ“"
    },
    DocumentType.HANDWRITTEN: {
        "name": "Handwritten Note",
        "description": "Extract text from handwritten documents",
        "prompt": "Extract all handwritten text from this image. Output plain text, preserving line breaks.",
        "icon": "✍️"
    }
}
# ========================================
# MODEL LOADING
# ========================================
@st.cache_resource
def load_glm_ocr():
    """Load the GLM-OCR processor and model once per server process.

    Cached with st.cache_resource so every Streamlit rerun reuses the same
    weights instead of reloading them.

    Returns:
        (processor, model, device) where device is "cuda" or "cpu".
    """
    model_id = "zai-org/GLM-OCR"
    use_gpu = torch.cuda.is_available()
    device = "cuda" if use_gpu else "cpu"
    with st.spinner("πŸ”„ Loading OCR model... (first time may take 1–3 minutes)"):
        processor = AutoProcessor.from_pretrained(
            model_id,
            trust_remote_code=True
        )
        model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            # bfloat16 halves memory on GPU; CPU inference stays in float32
            torch_dtype=torch.bfloat16 if use_gpu else torch.float32,
            device_map="auto" if use_gpu else None,
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )
        if not use_gpu:
            # device_map was None, so the model must be moved manually
            model = model.to(device)
        model.eval()
    return processor, model, device
# Load once at import time; st.cache_resource makes reruns reuse the weights.
# NOTE(review): this shows st.spinner BEFORE st.set_page_config below, but
# Streamlit requires set_page_config to be the first st.* command — confirm
# this does not raise StreamlitAPIException on the deployed version.
processor, model, device = load_glm_ocr()
# ========================================
# IMAGE PREPROCESSING
# ========================================
def preprocess_image(
    image: "Image.Image",
    enhance_contrast: bool = False,
    denoise: bool = False,
    sharpen: bool = False,
    auto_rotate: bool = False,
    prevent_cropping: bool = False
) -> "Image.Image":
    """
    Preprocess image with optional enhancements.

    Args:
        image: PIL Image (RGB)
        enhance_contrast: Apply CLAHE contrast enhancement
        denoise: Apply non-local-means denoising
        sharpen: Apply a 3x3 sharpening kernel
        auto_rotate: Attempt to auto-rotate text to horizontal
        prevent_cropping: Expand the canvas during rotation so no pixels
            are cut off (only meaningful together with auto_rotate)

    Returns:
        Preprocessed PIL Image (RGB)

    Raises:
        ValueError: If prevent_cropping is requested without auto_rotate.
    """
    if prevent_cropping and not auto_rotate:
        # Canvas expansion only happens inside the rotation step, so the
        # flag would be silently ignored without auto_rotate.
        raise ValueError("Auto-Rotate must be enabled when Prevent-Cropping is active")
    # Convert to OpenCV format; grayscale simplifies every filter below
    img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
    # Denoise
    if denoise:
        gray = cv2.fastNlMeansDenoising(gray, None, 10, 7, 21)
    # Enhance contrast
    if enhance_contrast:
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        gray = clahe.apply(gray)
    # Sharpen
    if sharpen:
        kernel = np.array([[-1, -1, -1],
                           [-1, 9, -1],
                           [-1, -1, -1]])
        gray = cv2.filter2D(gray, -1, kernel)
    # Auto-rotate: basic Hough-line deskew
    if auto_rotate:
        # Blur before edge detection so noise doesn't produce spurious lines
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)
        edges = cv2.Canny(blurred, 50, 150)
        lines = cv2.HoughLinesP(edges, 1, np.pi / 180, 100, minLineLength=80, maxLineGap=10)
        if lines is not None and len(lines) > 0:
            angles = []
            for line in lines:
                x1, y1, x2, y2 = line[0]
                angle = np.degrees(np.arctan2(y2 - y1, x2 - x1))
                # Only near-horizontal segments are taken as text baselines
                if -45 < angle < 45:
                    angles.append(angle)
            if len(angles) > 0:
                # Median is robust against outlier line segments
                median_angle = np.median(angles)
                # Defensive: fold near-vertical medians back into range
                # (cannot trigger after the -45..45 filter above)
                if abs(median_angle) > 45:
                    median_angle -= 90 * np.sign(median_angle)
                # In image coordinates (y down) rotating by the detected
                # angle maps the baseline direction back to horizontal
                rotation_angle = median_angle
                (h, w) = gray.shape[:2]
                center = (w // 2, h // 2)
                M = cv2.getRotationMatrix2D(center, rotation_angle, 1.0)
                if prevent_cropping:
                    # Compute the expanded canvas from the ORIGINAL w/h.
                    # (Bug fix: the old code reassigned w before computing h,
                    # so the new height was derived from the new width.)
                    cos = np.abs(M[0, 0])
                    sin = np.abs(M[0, 1])
                    new_w = int((h * sin) + (w * cos))
                    new_h = int((h * cos) + (w * sin))
                    # Shift so the rotated image stays centred on the canvas
                    M[0, 2] += (new_w / 2) - center[0]
                    M[1, 2] += (new_h / 2) - center[1]
                    w, h = new_w, new_h
                gray = cv2.warpAffine(gray, M, (w, h),
                                      flags=cv2.INTER_CUBIC,
                                      borderMode=cv2.BORDER_REPLICATE)
    # Convert back to RGB (model / Streamlit expect 3 channels)
    return Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
# ========================================
# OCR EXTRACTION
# ========================================
def extract_text(
    image: "Image.Image",
    prompt: str,
    max_tokens: int = 2048
) -> tuple[str, int]:
    """
    Extract text from an image using the module-level GLM-OCR model.

    Args:
        image: PIL Image to OCR
        prompt: Extraction prompt (plain-text or JSON-template instruction)
        max_tokens: Maximum number of new tokens to generate

    Returns:
        Tuple of (extracted_text, processing_time_ms)
    """
    start_time = datetime.datetime.now()
    # Single-turn chat message pairing the image with the instruction
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt}
            ]
        }
    ]
    # Tokenize through the model's chat template
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt"
    )
    # Move tensors to the model's device; token_type_ids is dropped because
    # generate() does not accept it for this model
    inputs = {k: v.to(device) for k, v in inputs.items() if k != "token_type_ids"}
    # Greedy decoding for deterministic OCR output.
    # (Fix: removed temperature=0.0 — it is ignored when do_sample=False and
    # makes recent transformers versions emit a warning.)
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=False
        )
    # Decode only the newly generated tokens (skip the prompt portion)
    output_text = processor.decode(
        generated_ids[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True
    )
    # Free GPU memory promptly between requests
    del inputs, generated_ids
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Elapsed wall-clock time in milliseconds
    processing_time = (datetime.datetime.now() - start_time).total_seconds() * 1000
    return output_text, int(processing_time)
# ========================================
# STREAMLIT UI
# ========================================
st.set_page_config(
    page_title="Universal OCR Scanner",
    page_icon="πŸ”",
    layout="wide",
    initial_sidebar_state="expanded"
)
# NOTE(review): Streamlit expects set_page_config to be the first st.* call,
# but load_glm_ocr() above already rendered a spinner — confirm this does
# not raise on the deployed Streamlit version.

# Seed every session-state key the app relies on so later reads never KeyError.
_STATE_DEFAULTS = {
    'should_process': False,
    'has_results': False,
    'output_text': "",
    'processing_time': 0,
    'doc_type': DocumentType.GENERAL,
    'current_file': None,
}
for _key, _default in _STATE_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Header
st.title("πŸ” Universal OCR Scanner")
st.markdown("Extract text and structured data from **any document** - receipts, IDs, invoices, forms, and more!")
# Sidebar - Document Type Selection
with st.sidebar:
    st.header("πŸ“‹ Document Type")
    # Radio over all DocumentType members, rendered as "<icon> <name>"
    doc_type = st.radio(
        "Select document type:",
        options=list(DocumentType),
        format_func=lambda x: f"{DOCUMENT_TEMPLATES[x]['icon']} {DOCUMENT_TEMPLATES[x]['name']}",
        label_visibility="collapsed"
    )
    # Show description for the selected type
    st.info(DOCUMENT_TEMPLATES[doc_type]['description'])
    st.markdown("---")
    # Preprocessing options (consumed by the processing section further down)
    st.header("βš™οΈ Image Enhancement")
    with st.expander("🎨 Preprocessing Options", expanded=False):
        enhance_contrast = st.checkbox("Enhance Contrast", value=False,
                                       help="Improve visibility of faded text")
        denoise = st.checkbox("Reduce Noise", value=False,
                              help="Remove image noise and artifacts")
        sharpen = st.checkbox("Sharpen Text", value=False,
                              help="Make text edges crisper")
        auto_rotate = st.checkbox("Auto-Rotate", value=False,
                                  help="Automatically straighten tilted documents")
        # preprocess_image raises if this is set without auto_rotate
        prevent_cropping = st.checkbox("Prevent-Cropping", value=False,
                                       help="Prevent cropping when rotate")
    st.markdown("---")
    # Advanced options
    with st.expander("πŸ”§ Advanced Options", expanded=False):
        show_preprocessed = st.checkbox("Show Preprocessed Image", value=False)
        max_tokens = st.slider("Max Output Tokens", 512, 4096, 2048, 256,
                               help="Increase for longer documents")
        custom_prompt = st.checkbox("Use Custom Prompt", value=False)
    st.markdown("---")
    # Usage tips
    st.caption("πŸ’‘ **Tips:**")
    st.caption("β€’ Use good lighting")
    st.caption("β€’ Avoid shadows")
    st.caption("β€’ Keep text horizontal")
    st.caption("β€’ Use high resolution images")
# Main content area
col1, col2 = st.columns([1, 1])
with col1:
    st.subheader("πŸ“€ Upload Document")
    # Tabs for the two input methods
    upload_tab, camera_tab = st.tabs(["πŸ“ Upload File", "πŸ“Έ Take Photo"])
    image = None
    with upload_tab:
        uploaded_file = st.file_uploader(
            "Choose an image...",
            type=["jpg", "jpeg", "png", "webp", "pdf"],
            help="Supported formats: JPG, PNG, WEBP, PDF"
        )
        if uploaded_file is not None:
            file_extension = uploaded_file.name.split('.')[-1].lower()
            if file_extension == 'pdf':
                # Open PDF from memory; NOTE: only the FIRST page is OCR'd
                doc = fitz.open(stream=uploaded_file.read(), filetype="pdf")
                page = doc.load_page(0)  # Choose First Page (0)
                # Render page to image (pixmap)
                pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))  # 2x zoom for better OCR result
                # Convert Pixmap to PIL Image Object
                img_data = pix.tobytes("png")
                image = Image.open(io.BytesIO(img_data)).convert("RGB")
                doc.close()
            else:
                # Normal image files
                image = Image.open(uploaded_file).convert("RGB")
            # Clear previous results when a new file is uploaded
            # NOTE(review): the `not in` clause is redundant — 'current_file'
            # is already seeded during session-state initialization above
            if 'current_file' not in st.session_state or st.session_state.current_file != uploaded_file.name:
                st.session_state.current_file = uploaded_file.name
                st.session_state.has_results = False
    with camera_tab:
        camera_picture = st.camera_input("Take a photo")
        if camera_picture is not None:
            image = Image.open(BytesIO(camera_picture.getvalue())).convert("RGB")
            # Clear previous results when a new photo is taken
            st.session_state.has_results = False
    # Show original image
    if image is not None:
        st.image(image, caption="Original Image", width="content")
with col2:
    st.subheader("πŸ“‹ Extraction Settings")
    # Either an editable prompt or a read-only view of the template prompt
    if custom_prompt:
        prompt = st.text_area(
            "Custom Extraction Prompt:",
            value=DOCUMENT_TEMPLATES[doc_type]['prompt'],
            height=200,
            help="Customize how the OCR extracts data",
            key="custom_prompt_text"
        )
    else:
        prompt = DOCUMENT_TEMPLATES[doc_type]['prompt']
        st.code(prompt, language="text")
    # Process button
    if image is not None:
        if st.button(
            "πŸš€ Extract Text",
            type="primary",
            width="content",
            key="extract_button"
        ):
            # Only set a flag here; the heavy lifting happens in the
            # processing section below during this same rerun
            st.session_state.should_process = True
    else:
        st.info("πŸ‘† Upload or capture an image to begin")
# Processing (runs only on the rerun triggered by the Extract button)
if image is not None and st.session_state.get('should_process', False):
    # Clear the flag immediately to prevent re-processing on the next rerun
    st.session_state.should_process = False
    with st.spinner("πŸ”„ Processing document..."):
        try:
            # Preprocess only when at least one enhancement is enabled
            if enhance_contrast or denoise or sharpen or auto_rotate or prevent_cropping:
                preprocessed_image = preprocess_image(
                    image,
                    enhance_contrast=enhance_contrast,
                    denoise=denoise,
                    sharpen=sharpen,
                    auto_rotate=auto_rotate,
                    prevent_cropping=prevent_cropping
                )
            else:
                preprocessed_image = image
            # Show preprocessed image side-by-side if requested
            # NOTE(review): `!=` triggers PIL content comparison; `is not`
            # would be cheaper and matches the intent (was any filter applied?)
            if show_preprocessed and preprocessed_image != image:
                st.subheader("πŸ”§ Preprocessed Image")
                col_a, col_b = st.columns(2)
                with col_a:
                    st.image(image, caption="Original", width="content")
                with col_b:
                    st.image(preprocessed_image, caption="Enhanced", width="content")
            # Run the OCR model
            output_text, processing_time = extract_text(
                preprocessed_image,
                prompt=prompt,
                max_tokens=max_tokens
            )
            # Store results in session state so the display section (and
            # later reruns, e.g. download-button clicks) can render them
            st.session_state.output_text = output_text
            st.session_state.processing_time = processing_time
            st.session_state.doc_type = doc_type
            st.session_state.preprocessed_image = preprocessed_image
            st.session_state.has_results = True
        except Exception as e:
            st.error(f"❌ Error during extraction: {str(e)}")
            import traceback
            with st.expander("Show Error Details"):
                st.code(traceback.format_exc())
            st.session_state.has_results = False
# Display results (kept separate from processing so download clicks and
# other reruns don't re-run the model)
if st.session_state.get('has_results', False):
    output_text = st.session_state.output_text
    processing_time = st.session_state.processing_time
    doc_type = st.session_state.doc_type
    # Falls back to `image` (may be None if the uploader was cleared)
    preprocessed_image = st.session_state.get('preprocessed_image', image)
    # Display success message
    st.success(f"βœ… Extraction complete! ({processing_time}ms)")
    # Try to parse as JSON for structured document types
    is_json = False
    parsed_data = None
    if doc_type in [DocumentType.ID_CARD, DocumentType.RECEIPT,
                    DocumentType.INVOICE, DocumentType.BUSINESS_CARD,
                    DocumentType.FORM]:
        try:
            # Strip markdown code fences the model often wraps JSON in
            clean_text = output_text
            if "```json" in clean_text:
                clean_text = clean_text.split("```json")[1].split("```")[0].strip()
            elif "```" in clean_text:
                clean_text = clean_text.split("```")[1].split("```")[0].strip()
            # Truncate if too long (likely recursive)
            # NOTE(review): truncating mid-document usually makes json.loads
            # fail below, so this path effectively falls back to plain text
            if len(clean_text) > 50000:  # Reasonable JSON should be much smaller
                st.warning("⚠️ Detected recursive JSON structure. Truncating...")
                clean_text = clean_text[:50000]
            parsed_data = json.loads(clean_text)
            # Flatten recursive structures
def flatten_dict(d, max_depth=2, current_depth=0):
    """Return *d* with dict nesting capped at *max_depth* levels.

    Non-dict inputs are returned unchanged. Dict values are recursed one
    level at a time; dicts sitting at the depth limit are dropped entirely,
    while lists and scalar values pass through untouched.
    """
    # Hard stop: anything at or beyond the limit is discarded outright.
    if current_depth >= max_depth:
        return {}
    if not isinstance(d, dict):
        return d
    result = {}
    for k in d:
        v = d[k]
        if not isinstance(v, dict):
            # Lists and scalars are kept as-is.
            result[k] = v
        elif current_depth + 1 < max_depth:
            result[k] = flatten_dict(v, max_depth, current_depth + 1)
        # else: nested dict at the depth limit — dropped on purpose.
    return result
            # Flatten the parsed data to cap runaway nesting from the model
            parsed_data = flatten_dict(parsed_data, max_depth=2)
            is_json = True
        except json.JSONDecodeError:
            # Model output was not valid JSON — fall back to plain text below
            is_json = False
        except Exception as e:
            st.warning(f"⚠️ JSON parsing issue: {str(e)}")
            is_json = False
    # Display based on type
    st.markdown("---")
    st.subheader("πŸ“„ Extracted Data")
    if is_json and parsed_data:
        # Structured data display: wide content column + narrow downloads column
        col_display, col_download = st.columns([2, 1])
        with col_display:
            # Format display based on document type
            if doc_type == DocumentType.RECEIPT:
                st.markdown("### 🧾 Receipt Details")
                # Merchant info
                if "merchant_name" in parsed_data:
                    st.markdown(f"**Merchant:** {parsed_data['merchant_name']}")
                if "date" in parsed_data:
                    st.markdown(f"**Date:** {parsed_data['date']}")
                if "time" in parsed_data:
                    st.markdown(f"**Time:** {parsed_data['time']}")
                # Items table
                if "items" in parsed_data and parsed_data["items"]:
                    st.markdown("**Items:**")
                    items_df = pd.DataFrame(parsed_data["items"])
                    st.dataframe(items_df, width="content", hide_index=True)
                # Totals
                # NOTE(review): the :.2f formats assume numeric values; a
                # string like "12.50" raises ValueError here — consider float()
                st.markdown("---")
                if "subtotal" in parsed_data:
                    st.markdown(f"**Subtotal:** ${parsed_data['subtotal']:.2f}")
                if "tax" in parsed_data:
                    st.markdown(f"**Tax:** ${parsed_data['tax']:.2f}")
                if "total" in parsed_data:
                    st.markdown(f"**Total:** ${parsed_data['total']:.2f}")
            elif doc_type == DocumentType.INVOICE:
                st.markdown("### πŸ“‹ Invoice Details")
                col_inv1, col_inv2 = st.columns(2)
                with col_inv1:
                    st.markdown("**Invoice Info:**")
                    if "invoice_number" in parsed_data:
                        st.text(f"Number: {parsed_data['invoice_number']}")
                    if "date" in parsed_data:
                        st.text(f"Date: {parsed_data['date']}")
                    if "due_date" in parsed_data:
                        st.text(f"Due: {parsed_data['due_date']}")
                with col_inv2:
                    if "vendor" in parsed_data:
                        st.markdown("**Vendor:**")
                        vendor = parsed_data["vendor"]
                        if isinstance(vendor, dict):
                            for k, v in vendor.items():
                                if v:  # skip empty fields
                                    st.text(f"{k.title()}: {v}")
                # Line items
                if "line_items" in parsed_data and parsed_data["line_items"]:
                    st.markdown("**Line Items:**")
                    items_df = pd.DataFrame(parsed_data["line_items"])
                    st.dataframe(items_df, width="content", hide_index=True)
                # Total
                if "total" in parsed_data:
                    st.markdown(f"### **Total: ${parsed_data['total']:.2f}**")
            else:
                # Generic structured display (ID card, business card, form)
                for key, value in parsed_data.items():
                    if isinstance(value, dict):
                        st.markdown(f"**{key.replace('_', ' ').title()}:**")
                        for k, v in value.items():
                            st.text(f" {k}: {v}")
                    elif isinstance(value, list):
                        st.markdown(f"**{key.replace('_', ' ').title()}:**")
                        if value and isinstance(value[0], dict):
                            # List of records renders as a table
                            df = pd.DataFrame(value)
                            st.dataframe(df, width="content", hide_index=True)
                        else:
                            for item in value:
                                st.text(f" β€’ {item}")
                    else:
                        st.markdown(f"**{key.replace('_', ' ').title()}:** {value}")
with col_download:
st.subheader("πŸ’Ύ Downloads")
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
# JSON download
json_str = json.dumps(parsed_data, ensure_ascii=False, indent=2)
st.download_button(
label="πŸ“„ JSON",
data=json_str,
file_name=f"{doc_type.value}_{timestamp}.json",
mime="application/json",
width="content"
)
# CSV download (flattened)
try:
# Flatten nested structures
flat_data = {}
for k, v in parsed_data.items():
if isinstance(v, (dict, list)):
flat_data[k] = json.dumps(v, ensure_ascii=False)
else:
flat_data[k] = v
df = pd.DataFrame([flat_data])
csv_buffer = StringIO()
df.to_csv(csv_buffer, index=False, encoding='utf-8')
st.download_button(
label="πŸ“Š CSV",
data=csv_buffer.getvalue(),
file_name=f"{doc_type.value}_{timestamp}.csv",
mime="text/csv",
width="content"
)
except:
pass
# Raw text download
st.download_button(
label="πŸ“ TXT",
data=output_text,
file_name=f"{doc_type.value}_{timestamp}.txt",
mime="text/plain",
width="content"
)
# Show raw JSON in expander
with st.expander("πŸ” View Raw JSON"):
st.json(parsed_data)
    else:
        # Plain text display (general/handwritten, or JSON parsing failed)
        st.text_area(
            "Extracted Text:",
            value=output_text,
            height=400,
            label_visibility="collapsed"
        )
        # Download the raw extracted text
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        st.download_button(
            label="πŸ’Ύ Download as TXT",
            data=output_text,
            file_name=f"extracted_text_{timestamp}.txt",
            mime="text/plain"
        )
# Footer: three evenly spaced captions
st.markdown("---")
_footer_items = (
    "⚑ Powered by GLM-OCR",
    f"πŸ–₯️ Device: {device.upper()}",
    "🌟 Universal Document Scanner",
)
for _col, _caption in zip(st.columns(3), _footer_items):
    with _col:
        st.caption(_caption)