"""Universal OCR Scanner — Streamlit app.

Extracts plain text or structured JSON from uploaded/captured document
images (and first pages of PDFs) using the GLM-OCR vision-language model,
with optional OpenCV preprocessing (denoise, contrast, sharpen, deskew).
"""

import streamlit as st
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
import json
import cv2
import numpy as np
import pandas as pd
from io import BytesIO, StringIO
import datetime
from enum import Enum
from typing import Dict, Any, Optional
import fitz
import io

# ========================================
# DOCUMENT TYPES & TEMPLATES
# ========================================

class DocumentType(str, Enum):
    """Supported document types"""
    GENERAL = "general"
    ID_CARD = "id_card"
    RECEIPT = "receipt"
    INVOICE = "invoice"
    BUSINESS_CARD = "business_card"
    FORM = "form"
    HANDWRITTEN = "handwritten"


# Per-type UI metadata and the extraction prompt sent to the model.
DOCUMENT_TEMPLATES = {
    DocumentType.GENERAL: {
        "name": "General Text",
        "description": "Extract all text from any document",
        "prompt": "Extract all text from this image. Preserve the layout and structure. Output plain text.",
        "icon": "\U0001F4C4"
    },
    DocumentType.ID_CARD: {
        "name": "ID Card / Passport",
        "description": "Extract structured data from identity documents",
        "prompt": """Extract structured data from this identity document. Output ONLY valid JSON with these exact fields, no nested objects:
{
  "document_type": "",
  "full_name": "",
  "sex": "",
  "date_of_birth": "",
  "date_of_expiry": "",
  "nationality": "",
  "document_number": "",
  "place_of_birth": "",
  "personal_number": ""
}
IMPORTANT: Do NOT create nested or recursive structures. Keep it flat and simple.""",
        "icon": "\U0001F194"
    },
    DocumentType.RECEIPT: {
        "name": "Receipt",
        "description": "Extract items, prices, and totals from receipts",
        "prompt": """Extract information from this receipt. Output ONLY valid JSON:
{
  "merchant_name": "",
  "date": "",
  "time": "",
  "items": [
    {"name": "", "quantity": 1, "price": 0.0}
  ],
  "subtotal": 0.0,
  "tax": 0.0,
  "total": 0.0,
  "payment_method": ""
}""",
        "icon": "\U0001F9FE"
    },
    DocumentType.INVOICE: {
        "name": "Invoice",
        "description": "Extract invoice details and line items",
        "prompt": """Extract information from this invoice. Output ONLY valid JSON:
{
  "invoice_number": "",
  "date": "",
  "due_date": "",
  "vendor": {
    "name": "",
    "address": "",
    "contact": ""
  },
  "customer": {
    "name": "",
    "address": "",
    "contact": ""
  },
  "line_items": [
    {"description": "", "quantity": 1, "unit_price": 0.0, "amount": 0.0}
  ],
  "subtotal": 0.0,
  "tax": 0.0,
  "total": 0.0
}""",
        "icon": "\U0001F4CB"
    },
    DocumentType.BUSINESS_CARD: {
        "name": "Business Card",
        "description": "Extract contact information",
        "prompt": """Extract contact information from this business card. Output ONLY valid JSON:
{
  "name": "",
  "title": "",
  "company": "",
  "email": "",
  "phone": "",
  "mobile": "",
  "website": "",
  "address": "",
  "social_media": {}
}""",
        "icon": "\U0001F4BC"
    },
    DocumentType.FORM: {
        "name": "Form",
        "description": "Extract filled form data",
        "prompt": """Extract all fields and values from this form. Output ONLY valid JSON with field names as keys and filled values:
{
  "field_name": "value"
}""",
        "icon": "\U0001F4DD"
    },
    DocumentType.HANDWRITTEN: {
        "name": "Handwritten Note",
        "description": "Extract text from handwritten documents",
        "prompt": "Extract all handwritten text from this image. Output plain text, preserving line breaks.",
        "icon": "\u270D\uFE0F"
    }
}

# ========================================
# MODEL LOADING
# ========================================

@st.cache_resource
def load_glm_ocr():
    """Load GLM-OCR model (cached).

    Returns:
        Tuple of (processor, model, device). Cached across reruns by
        Streamlit so the weights are only downloaded/loaded once.
    """
    MODEL_NAME = "zai-org/GLM-OCR"
    with st.spinner("\U0001F504 Loading OCR model... (first time may take 1\u20133 minutes)"):
        processor = AutoProcessor.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True
        )
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # bfloat16 halves memory on GPU; CPU inference needs float32.
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        model = AutoModelForImageTextToText.from_pretrained(
            MODEL_NAME,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )
        # device_map is only used on GPU; move explicitly for CPU.
        if not torch.cuda.is_available():
            model = model.to(device)
        model.eval()
    return processor, model, device


processor, model, device = load_glm_ocr()

# ========================================
# IMAGE PREPROCESSING
# ========================================

def preprocess_image(
    image: Image.Image,
    enhance_contrast: bool = False,
    denoise: bool = False,
    sharpen: bool = False,
    auto_rotate: bool = False,
    prevent_cropping: bool = False
) -> Image.Image:
    """
    Preprocess image with optional enhancements

    Args:
        image: PIL Image
        enhance_contrast: Apply CLAHE contrast enhancement
        denoise: Apply denoising
        sharpen: Apply sharpening
        auto_rotate: Attempt to auto-rotate text to horizontal
        prevent_cropping: Expand the canvas during rotation so no
            content is cut off (requires auto_rotate)

    Returns:
        Preprocessed PIL Image

    Raises:
        ValueError: If prevent_cropping is requested without auto_rotate.
    """
    if prevent_cropping and not auto_rotate:
        # prevent_cropping only affects the rotation step, so it is
        # meaningless (and likely a user mistake) without auto_rotate.
        raise ValueError("Auto-Rotate must be enabled when Prevent-Cropping is active")

    # Convert to OpenCV format (grayscale simplifies all filters below)
    img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)

    # Denoise
    if denoise:
        gray = cv2.fastNlMeansDenoising(gray, None, 10, 7, 21)

    # Enhance contrast (CLAHE: local histogram equalization)
    if enhance_contrast:
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        gray = clahe.apply(gray)

    # Sharpen
    if sharpen:
        kernel = np.array([[-1, -1, -1],
                           [-1, 9, -1],
                           [-1, -1, -1]])
        gray = cv2.filter2D(gray, -1, kernel)

    # Auto-rotate (basic implementation)
    if auto_rotate:
        # Blur first so edge detection finds text lines, not noise.
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)
        edges = cv2.Canny(blurred, 50, 150)
        lines = cv2.HoughLinesP(edges, 1, np.pi / 180, 100,
                                minLineLength=80, maxLineGap=10)
        if lines is not None and len(lines) > 0:
            angles = []
            for line in lines:
                x1, y1, x2, y2 = line[0]
                angle = np.degrees(np.arctan2(y2 - y1, x2 - x1))
                # Keep only near-horizontal segments (text baselines).
                if -45 < angle < 45:
                    angles.append(angle)
            if len(angles) > 0:
                # Use median to be robust against outliers
                median_angle = np.median(angles)
                # If lines are detected near vertical, adjust (rare for text)
                if abs(median_angle) > 45:
                    median_angle -= 90 * np.sign(median_angle)
                rotation_angle = median_angle
                (h, w) = gray.shape[:2]
                center = (w // 2, h // 2)
                M = cv2.getRotationMatrix2D(center, rotation_angle, 1.0)
                if prevent_cropping:
                    # Standard rotate-without-crop bound: both new
                    # dimensions must be computed from the ORIGINAL
                    # w/h before either is overwritten (the previous
                    # code reused the already-updated width when
                    # computing the height, producing a wrong canvas).
                    cos = np.abs(M[0, 0])
                    sin = np.abs(M[0, 1])
                    new_w = int((h * sin) + (w * cos))
                    new_h = int((h * cos) + (w * sin))
                    # Shift so the rotated content is centered on the
                    # enlarged canvas.
                    M[0, 2] += (new_w / 2) - center[0]
                    M[1, 2] += (new_h / 2) - center[1]
                    w, h = new_w, new_h
                gray = cv2.warpAffine(gray, M, (w, h),
                                      flags=cv2.INTER_CUBIC,
                                      borderMode=cv2.BORDER_REPLICATE)

    # Convert back to RGB
    return Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))

# ========================================
# OCR EXTRACTION
# ========================================

def extract_text(
    image: Image.Image,
    prompt: str,
    max_tokens: int = 2048
) -> tuple[str, int]:
    """
    Extract text from image using GLM-OCR

    Args:
        image: PIL Image
        prompt: Extraction prompt
        max_tokens: Maximum tokens to generate

    Returns:
        Tuple of (extracted_text, processing_time_ms)
    """
    start_time = datetime.datetime.now()

    # Prepare messages in the chat format the processor expects.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt}
            ]
        }
    ]

    # Apply chat template
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt"
    )

    # Move to device (token_type_ids is not accepted by generate()).
    inputs = {k: v.to(device) for k, v in inputs.items() if k != "token_type_ids"}

    # Greedy decoding for deterministic OCR output. Note: sampling
    # parameters such as temperature must NOT be passed alongside
    # do_sample=False — transformers warns/ignores them, and 0.0 is an
    # invalid temperature for sampling anyway.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=False
        )

    # Decode only the newly generated tokens (skip the prompt tokens).
    output_text = processor.decode(
        generated_ids[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True
    )

    # Cleanup — free GPU memory between requests.
    del inputs, generated_ids
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Calculate processing time
    processing_time = (datetime.datetime.now() - start_time).total_seconds() * 1000
    return output_text, int(processing_time)

# ========================================
# STREAMLIT UI
# ========================================

st.set_page_config(
    page_title="Universal OCR Scanner",
    page_icon="\U0001F50D",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Initialize session state
if 'should_process' not in st.session_state:
    st.session_state.should_process = False
if 'has_results' not in st.session_state:
    st.session_state.has_results = False
if 'output_text' not in st.session_state:
    st.session_state.output_text = ""
if 'processing_time' not in st.session_state:
    st.session_state.processing_time = 0
if 'doc_type' not in st.session_state:
    st.session_state.doc_type = DocumentType.GENERAL
if 'current_file' not in st.session_state:
    st.session_state.current_file = None

# Header
st.title("\U0001F50D Universal OCR Scanner")
st.markdown("Extract text and structured data from **any document** - receipts, IDs, invoices, forms, and more!")

# Sidebar - Document Type Selection
with st.sidebar:
    st.header("\U0001F4CB Document Type")

    # Show document type cards
    doc_type = st.radio(
        "Select document type:",
        options=list(DocumentType),
        format_func=lambda x: f"{DOCUMENT_TEMPLATES[x]['icon']} {DOCUMENT_TEMPLATES[x]['name']}",
        label_visibility="collapsed"
    )

    # Show description
    st.info(DOCUMENT_TEMPLATES[doc_type]['description'])

    st.markdown("---")

    # Preprocessing options
    st.header("\u2699\uFE0F Image Enhancement")
    with st.expander("\U0001F3A8 Preprocessing Options", expanded=False):
        enhance_contrast = st.checkbox("Enhance Contrast", value=False,
                                       help="Improve visibility of faded text")
        denoise = st.checkbox("Reduce Noise", value=False,
                              help="Remove image noise and artifacts")
        sharpen = st.checkbox("Sharpen Text", value=False,
                              help="Make text edges crisper")
        auto_rotate = st.checkbox("Auto-Rotate", value=False,
                                  help="Automatically straighten tilted documents")
        prevent_cropping = st.checkbox("Prevent-Cropping", value=False,
                                       help="Prevent cropping when rotate")

    st.markdown("---")

    # Advanced options
    with st.expander("\U0001F527 Advanced Options", expanded=False):
        show_preprocessed = st.checkbox("Show Preprocessed Image", value=False)
        max_tokens = st.slider("Max Output Tokens", 512, 4096, 2048, 256,
                               help="Increase for longer documents")
        custom_prompt = st.checkbox("Use Custom Prompt", value=False)

    st.markdown("---")

    # Info
    st.caption("\U0001F4A1 **Tips:**")
    st.caption("\u2022 Use good lighting")
    st.caption("\u2022 Avoid shadows")
    st.caption("\u2022 Keep text horizontal")
    st.caption("\u2022 Use high resolution images")

# Main content area
col1, col2 = st.columns([1, 1])

with col1:
    st.subheader("\U0001F4E4 Upload Document")

    # Tabs for upload methods
    upload_tab, camera_tab = st.tabs(["\U0001F4C1 Upload File", "\U0001F4F8 Take Photo"])

    image = None

    with upload_tab:
        uploaded_file = st.file_uploader(
            "Choose an image...",
            type=["jpg", "jpeg", "png", "webp", "pdf"],
            help="Supported formats: JPG, PNG, WEBP, PDF"
        )
        if uploaded_file is not None:
            file_extension = uploaded_file.name.split('.')[-1].lower()
            if file_extension == 'pdf':
                # open PDF from memory
                doc = fitz.open(stream=uploaded_file.read(), filetype="pdf")
                page = doc.load_page(0)  # Choose First Page (0)
                # Render page to image (pixmap)
                pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))  # 2x zoom for better OCR result
                # Convert Pixmap to PIL Image Object
                img_data = pix.tobytes("png")
                image = Image.open(io.BytesIO(img_data)).convert("RGB")
                doc.close()
            else:
                # Normal image files
                image = Image.open(uploaded_file).convert("RGB")

            # Clear previous results when new image uploaded
            if 'current_file' not in st.session_state or st.session_state.current_file != uploaded_file.name:
                st.session_state.current_file = uploaded_file.name
                st.session_state.has_results = False

    with camera_tab:
        camera_picture = st.camera_input("Take a photo")
        if camera_picture is not None:
            image = Image.open(BytesIO(camera_picture.getvalue())).convert("RGB")
            # Clear previous results when new photo taken
            st.session_state.has_results = False

    # Show original image
    if image is not None:
        st.image(image, caption="Original Image", width="content")

with col2:
    st.subheader("\U0001F4CB Extraction Settings")

    # Show/edit prompt
    if custom_prompt:
        prompt = st.text_area(
            "Custom Extraction Prompt:",
            value=DOCUMENT_TEMPLATES[doc_type]['prompt'],
            height=200,
            help="Customize how the OCR extracts data",
            key="custom_prompt_text"
        )
    else:
        prompt = DOCUMENT_TEMPLATES[doc_type]['prompt']
        st.code(prompt, language="text")

    # Process button
    if image is not None:
        if st.button(
            "\U0001F680 Extract Text",
            type="primary",
            width="content",
            key="extract_button"
        ):
            # Trigger processing by setting session state
            st.session_state.should_process = True
    else:
        st.info("\U0001F446 Upload or capture an image to begin")

# Processing (only run when button is clicked)
if image is not None and st.session_state.get('should_process', False):
    # Clear the flag immediately to prevent re-processing on next rerun
    st.session_state.should_process = False

    with st.spinner("\U0001F504 Processing document..."):
        try:
            # Preprocess image only if at least one enhancement is on.
            if enhance_contrast or denoise or sharpen or auto_rotate or prevent_cropping:
                preprocessed_image = preprocess_image(
                    image,
                    enhance_contrast=enhance_contrast,
                    denoise=denoise,
                    sharpen=sharpen,
                    auto_rotate=auto_rotate,
                    prevent_cropping=prevent_cropping
                )
            else:
                preprocessed_image = image

            # Show preprocessed if requested. Identity check ("is not")
            # is the intent here: was a NEW image object produced?
            # (PIL's __eq__ would do a full pixel comparison.)
            if show_preprocessed and preprocessed_image is not image:
                st.subheader("\U0001F527 Preprocessed Image")
                col_a, col_b = st.columns(2)
                with col_a:
                    st.image(image, caption="Original", width="content")
                with col_b:
                    st.image(preprocessed_image, caption="Enhanced", width="content")

            # Extract text
            output_text, processing_time = extract_text(
                preprocessed_image,
                prompt=prompt,
                max_tokens=max_tokens
            )

            # Store results in session state
            st.session_state.output_text = output_text
            st.session_state.processing_time = processing_time
            st.session_state.doc_type = doc_type
            st.session_state.preprocessed_image = preprocessed_image
            st.session_state.has_results = True

        except Exception as e:
            st.error(f"\u274C Error during extraction: {str(e)}")
            import traceback
            with st.expander("Show Error Details"):
                st.code(traceback.format_exc())
            st.session_state.has_results = False

# Display results (separate from processing)
if st.session_state.get('has_results', False):
    output_text = st.session_state.output_text
    processing_time = st.session_state.processing_time
    doc_type = st.session_state.doc_type
    preprocessed_image = st.session_state.get('preprocessed_image', image)

    # Display success message
    st.success(f"\u2705 Extraction complete! ({processing_time}ms)")

    # Try to parse as JSON for structured documents
    is_json = False
    parsed_data = None

    if doc_type in [DocumentType.ID_CARD, DocumentType.RECEIPT, DocumentType.INVOICE,
                    DocumentType.BUSINESS_CARD, DocumentType.FORM]:
        try:
            # Clean JSON from markdown code fences the model may emit.
            clean_text = output_text
            if "```json" in clean_text:
                clean_text = clean_text.split("```json")[1].split("```")[0].strip()
            elif "```" in clean_text:
                clean_text = clean_text.split("```")[1].split("```")[0].strip()

            # Truncate if too long (likely recursive)
            if len(clean_text) > 50000:  # Reasonable JSON should be much smaller
                st.warning("\u26A0\uFE0F Detected recursive JSON structure. Truncating...")
                clean_text = clean_text[:50000]

            parsed_data = json.loads(clean_text)

            # Flatten recursive structures
            def flatten_dict(d, max_depth=2, current_depth=0):
                """Remove recursive nested structures"""
                if current_depth >= max_depth:
                    return {}
                if not isinstance(d, dict):
                    return d
                flattened = {}
                for key, value in d.items():
                    if isinstance(value, dict):
                        # Only keep first level of nesting
                        if current_depth < max_depth - 1:
                            flattened[key] = flatten_dict(value, max_depth, current_depth + 1)
                        # Skip deeply nested structures
                    elif isinstance(value, list):
                        # Keep lists but limit depth
                        flattened[key] = value
                    else:
                        flattened[key] = value
                return flattened

            # Flatten the parsed data
            parsed_data = flatten_dict(parsed_data, max_depth=2)
            is_json = True
        except json.JSONDecodeError:
            is_json = False
        except Exception as e:
            st.warning(f"\u26A0\uFE0F JSON parsing issue: {str(e)}")
            is_json = False

    # Display based on type
    st.markdown("---")
    st.subheader("\U0001F4C4 Extracted Data")

    if is_json and parsed_data:
        # Structured data display
        col_display, col_download = st.columns([2, 1])

        with col_display:
            # Format display based on document type
            if doc_type == DocumentType.RECEIPT:
                st.markdown("### \U0001F9FE Receipt Details")
                # Merchant info
                if "merchant_name" in parsed_data:
                    st.markdown(f"**Merchant:** {parsed_data['merchant_name']}")
                if "date" in parsed_data:
                    st.markdown(f"**Date:** {parsed_data['date']}")
                if "time" in parsed_data:
                    st.markdown(f"**Time:** {parsed_data['time']}")
                # Items table
                if "items" in parsed_data and parsed_data["items"]:
                    st.markdown("**Items:**")
                    items_df = pd.DataFrame(parsed_data["items"])
                    st.dataframe(items_df, width="content", hide_index=True)
                # Totals
                st.markdown("---")
                if "subtotal" in parsed_data:
                    st.markdown(f"**Subtotal:** ${parsed_data['subtotal']:.2f}")
                if "tax" in parsed_data:
                    st.markdown(f"**Tax:** ${parsed_data['tax']:.2f}")
                if "total" in parsed_data:
                    st.markdown(f"**Total:** ${parsed_data['total']:.2f}")

            elif doc_type == DocumentType.INVOICE:
                st.markdown("### \U0001F4CB Invoice Details")
                col_inv1, col_inv2 = st.columns(2)
                with col_inv1:
                    st.markdown("**Invoice Info:**")
                    if "invoice_number" in parsed_data:
                        st.text(f"Number: {parsed_data['invoice_number']}")
                    if "date" in parsed_data:
                        st.text(f"Date: {parsed_data['date']}")
                    if "due_date" in parsed_data:
                        st.text(f"Due: {parsed_data['due_date']}")
                with col_inv2:
                    if "vendor" in parsed_data:
                        st.markdown("**Vendor:**")
                        vendor = parsed_data["vendor"]
                        if isinstance(vendor, dict):
                            for k, v in vendor.items():
                                if v:
                                    st.text(f"{k.title()}: {v}")
                # Line items
                if "line_items" in parsed_data and parsed_data["line_items"]:
                    st.markdown("**Line Items:**")
                    items_df = pd.DataFrame(parsed_data["line_items"])
                    st.dataframe(items_df, width="content", hide_index=True)
                # Total
                if "total" in parsed_data:
                    st.markdown(f"### **Total: ${parsed_data['total']:.2f}**")

            else:
                # Generic structured data display
                for key, value in parsed_data.items():
                    if isinstance(value, dict):
                        st.markdown(f"**{key.replace('_', ' ').title()}:**")
                        for k, v in value.items():
                            st.text(f"  {k}: {v}")
                    elif isinstance(value, list):
                        st.markdown(f"**{key.replace('_', ' ').title()}:**")
                        if value and isinstance(value[0], dict):
                            df = pd.DataFrame(value)
                            st.dataframe(df, width="content", hide_index=True)
                        else:
                            for item in value:
                                st.text(f"  \u2022 {item}")
                    else:
                        st.markdown(f"**{key.replace('_', ' ').title()}:** {value}")

        with col_download:
            st.subheader("\U0001F4BE Downloads")
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

            # JSON download
            json_str = json.dumps(parsed_data, ensure_ascii=False, indent=2)
            st.download_button(
                label="\U0001F4C4 JSON",
                data=json_str,
                file_name=f"{doc_type.value}_{timestamp}.json",
                mime="application/json",
                width="content"
            )

            # CSV download (flattened)
            try:
                # Flatten nested structures into JSON strings so the
                # result is a single-row table.
                flat_data = {}
                for k, v in parsed_data.items():
                    if isinstance(v, (dict, list)):
                        flat_data[k] = json.dumps(v, ensure_ascii=False)
                    else:
                        flat_data[k] = v
                df = pd.DataFrame([flat_data])
                csv_buffer = StringIO()
                df.to_csv(csv_buffer, index=False, encoding='utf-8')
                st.download_button(
                    label="\U0001F4CA CSV",
                    data=csv_buffer.getvalue(),
                    file_name=f"{doc_type.value}_{timestamp}.csv",
                    mime="text/csv",
                    width="content"
                )
            except Exception:
                # Best-effort: CSV export is optional; JSON/TXT above
                # already cover the data. (Was a bare `except:`, which
                # would also swallow SystemExit/KeyboardInterrupt.)
                pass

            # Raw text download
            st.download_button(
                label="\U0001F4DD TXT",
                data=output_text,
                file_name=f"{doc_type.value}_{timestamp}.txt",
                mime="text/plain",
                width="content"
            )

        # Show raw JSON in expander
        with st.expander("\U0001F50D View Raw JSON"):
            st.json(parsed_data)

    else:
        # Plain text display
        st.text_area(
            "Extracted Text:",
            value=output_text,
            height=400,
            label_visibility="collapsed"
        )

        # Download
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        st.download_button(
            label="\U0001F4BE Download as TXT",
            data=output_text,
            file_name=f"extracted_text_{timestamp}.txt",
            mime="text/plain"
        )

# Footer
st.markdown("---")
col_footer1, col_footer2, col_footer3 = st.columns(3)
with col_footer1:
    st.caption("\u26A1 Powered by GLM-OCR")
with col_footer2:
    st.caption(f"\U0001F5A5\uFE0F Device: {device.upper()}")
with col_footer3:
    st.caption("\U0001F31F Universal Document Scanner")