"""Universal OCR Scanner — Streamlit app.

Extracts plain text or structured JSON from document images using one of two
vision-language OCR models (GLM-OCR or dots.ocr), with optional OpenCV-based
preprocessing (contrast enhancement, denoising, sharpening, deskew).
"""

# NOTE: `spaces` must be imported before torch/CUDA is touched on HF Spaces
# (ZeroGPU hardware registration happens at import time).
import spaces  # noqa: F401
import streamlit as st
from PIL import Image
import torch
from huggingface_hub import snapshot_download
from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    AutoModelForCausalLM,
)
import json
import cv2
import numpy as np
import pandas as pd
from io import BytesIO, StringIO
import datetime
from enum import Enum
from typing import Optional, Tuple

# ----------------------------
# Optional dependency (recommended for dots.ocr)
# ----------------------------
try:
    from qwen_vl_utils import process_vision_info  # type: ignore
except Exception:
    # extract_text_dots has a single-image fallback when the helper is absent.
    process_vision_info = None


# ========================================
# MODEL TYPES
# ========================================
class OCRModel(str, Enum):
    """Selectable OCR backends."""
    GLM_OCR = "glm_ocr"
    DOTS_OCR = "dots_ocr"


# Per-model UI metadata: display name, icon, and Hugging Face repo id.
MODEL_UI = {
    OCRModel.GLM_OCR: {"name": "GLM-OCR", "icon": "🟦", "hf_id": "zai-org/GLM-OCR"},
    OCRModel.DOTS_OCR: {"name": "dots.ocr", "icon": "🟩", "hf_id": "rednote-hilab/dots.ocr"},
}


# ========================================
# DOCUMENT TYPES & TEMPLATES
# ========================================
class DocumentType(str, Enum):
    """Supported document types"""
    GENERAL = "general"
    FULL_JSON_SCHEMA = "full_json_schema"
    SIMPLE_TITLE_JSON = "simple_title_json"
    LOCALIZED_TITLE_JSON = "localized_title_json"
    GROUNDED_TITLE_JSON = "grounded_title_json"
    HANDWRITTEN = "handwritten"


# Prompt templates keyed by document type; each entry drives the sidebar UI
# (name/description/icon) and the extraction prompt sent to the model.
DOCUMENT_TEMPLATES = {
    DocumentType.GENERAL: {
        "name": "General Text",
        "description": "Extract all text from any document",
        "prompt": "Extract all text from this image. Preserve the layout and structure. Output plain text.",
        "icon": "πŸ“„"
    },
    DocumentType.FULL_JSON_SCHEMA: {
        "name": "Full Json Schema",
        "description": "Extract structured data from this cover page",
        "prompt": """Analyze this thesis/dissertation cover image and extract ONLY visible information.

CRITICAL: Only extract information that is CLEARLY VISIBLE on the page. DO NOT invent, guess, or hallucinate any data. If a field is not visible, use null.

Return ONLY valid JSON with this exact structure:
{
  "title": "Main title of the thesis or dissertation as it appears on the title page",
  "subtitle": "Subtitle or remainder of the title, usually following a colon; null if not present",
  "author": "Full name of the author (student) who wrote the thesis or dissertation",
  "degree_type": "Academic degree sought by the author (e.g. PhD, Doctorate, Master's degree, Master's thesis)",
  "discipline": "Academic field or discipline of the thesis if explicitly stated; null if not present. Possible values: MathΓ©matiques|Physics|Biology|others",
  "granting_institution": "Institution where the thesis was submitted and the degree is granted (degree-granting institution)",
  "doctoral_school": "Doctoral school or graduate program, if explicitly mentioned; null if not present",
  "co_tutelle_institutions": "List of institutions involved in a joint supervision or co-tutelle agreement; empty list if none",
  "partner_institutions": "List of partner institutions associated with the thesis but not granting the degree; empty list if none",
  "defense_year": "Year the thesis or dissertation was defended, in YYYY format; null if not visible",
  "defense_place": "City or place where the defense took place, if stated; null if not present",
  "thesis_advisor": "Main thesis advisor or supervisor (director of thesis); full name; null if not present",
  "co_advisors": "List of co-advisors or co-supervisors if explicitly mentioned; full names; empty list if none",
  "jury_president": "President or chair of the thesis examination committee, if specified; null if not present",
  "reviewers": "List of reviewers or rapporteurs of the thesis, if specified; full names; empty list if none",
  "committee_members": "List of other thesis committee or jury members, excluding advisor and reviewers; full names; empty list if none",
  "language": "Language in which the thesis is written, if explicitly stated; null if not present",
  "confidence": "Confidence score between 0.0 and 1.0 indicating reliability of the extracted metadata"
}

IMPORTANT: Return null for any field where information is NOT clearly visible. Return ONLY the JSON, no explanation.""",
        "icon": "πŸ†”"
    },
    DocumentType.SIMPLE_TITLE_JSON: {
        "name": "Simple Title Json",
        "description": "Extract title from this cover page",
        "prompt": """Extract the document title from this cover page. Output ONLY valid JSON: { "title": "" }""",
        "icon": "🧾"
    },
    DocumentType.LOCALIZED_TITLE_JSON: {
        "name": "Localized Title Json",
        "description": "Extract localized title from this cover page",
        "prompt": """Extract the document title from the middle central block of this cover page. Output ONLY valid JSON: { "title": "" }""",
        "icon": "🧾"
    },
    DocumentType.GROUNDED_TITLE_JSON: {
        "name": "Grounded Title Json",
        "description": "Extract localized title from this cover page",
        "prompt": """Extract the document title usually located around (0.5015,0.442) from this cover page. Output ONLY valid JSON: { "title": "" }""",
        "icon": "πŸ“‹"
    },
    DocumentType.HANDWRITTEN: {
        "name": "Handwritten Note",
        "description": "Extract text from handwritten documents",
        "prompt": "Extract all handwritten text from this image. Output plain text, preserving line breaks.",
        "icon": "✍️"
    }
}


# ========================================
# MODEL LOADING
# ========================================
@st.cache_resource
def load_glm_ocr():
    """Load GLM-OCR model (cached).

    Returns:
        (processor, model, device) — processor and eval-mode model, plus the
        device string ("cuda" or "cpu") the model lives on.
    """
    model_name = MODEL_UI[OCRModel.GLM_OCR]["hf_id"]
    with st.spinner("πŸ”„ Loading GLM-OCR... (first time may take 1–3 minutes)"):
        processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # bf16 on GPU for memory; full precision on CPU (bf16 CPU support varies).
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        # With device_map=None the model stays on CPU until moved explicitly.
        if not torch.cuda.is_available():
            model = model.to(device)
        model.eval()
    return processor, model, device


@st.cache_resource
def load_dots_ocr():
    """Load dots.ocr model (cached).

    Downloads the repo into a local directory first because dots.ocr's docs
    recommend avoiding '.' in the local directory name used for
    trust_remote_code loading.

    Returns:
        (processor, model, device) as in load_glm_ocr.
    """
    repo_id = MODEL_UI[OCRModel.DOTS_OCR]["hf_id"]
    model_path = "./models/DotsOCR"
    snapshot_download(
        repo_id=repo_id,
        local_dir=model_path,
        # Deprecated/ignored on recent huggingface_hub (real files are always
        # downloaded into local_dir); kept for compatibility with older hubs.
        local_dir_use_symlinks=False,
    )
    with st.spinner("πŸ”„ Loading dots.ocr... (first time may take 1–3 minutes)"):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # float32 everywhere for this model (bf16 previously caused issues);
        # flash_attention_2 could be enabled via attn_implementation if installed.
        dtype = torch.float32
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        if not torch.cuda.is_available():
            model = model.to(device)
        model.eval()
    return processor, model, device


def get_loaded_model(selected: OCRModel):
    """Return (processor, model, device) for selected model, cached by Streamlit."""
    if selected == OCRModel.GLM_OCR:
        return load_glm_ocr()
    return load_dots_ocr()


# ========================================
# IMAGE PREPROCESSING
# ========================================
def preprocess_image(
    image: Image.Image,
    enhance_contrast: bool = False,
    denoise: bool = False,
    sharpen: bool = False,
    auto_rotate: bool = False,
    prevent_cropping: bool = False,
) -> Image.Image:
    """Apply optional OpenCV enhancements and return an RGB PIL image.

    Args:
        image: Source RGB image.
        enhance_contrast: Apply CLAHE for faded text.
        denoise: Non-local-means denoising.
        sharpen: 3x3 sharpening kernel.
        auto_rotate: Estimate skew from Hough lines and deskew.
        prevent_cropping: Expand the canvas during rotation so no pixels are
            lost; only meaningful (and only allowed) with auto_rotate.

    Raises:
        ValueError: If prevent_cropping is requested without auto_rotate.
    """
    if prevent_cropping and not auto_rotate:
        # Argument-misuse is a ValueError, not a bare Exception.
        raise ValueError("Auto-Rotate must be enabled when Prevent-Cropping is active")

    img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)

    if denoise:
        gray = cv2.fastNlMeansDenoising(gray, None, 10, 7, 21)

    if enhance_contrast:
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        gray = clahe.apply(gray)

    if sharpen:
        kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
        gray = cv2.filter2D(gray, -1, kernel)

    if auto_rotate:
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)
        edges = cv2.Canny(blurred, 50, 150)
        lines = cv2.HoughLinesP(edges, 1, np.pi / 180, 100, minLineLength=80, maxLineGap=10)
        if lines is not None and len(lines) > 0:
            # Keep near-horizontal line angles only; the median is robust to
            # outliers from non-text edges.
            angles = []
            for line in lines:
                x1, y1, x2, y2 = line[0]
                angle = np.degrees(np.arctan2(y2 - y1, x2 - x1))
                if -45 < angle < 45:
                    angles.append(angle)
            if angles:
                median_angle = float(np.median(angles))
                if abs(median_angle) > 45:
                    median_angle -= 90 * np.sign(median_angle)
                (h0, w0) = gray.shape[:2]
                center = (w0 // 2, h0 // 2)
                M = cv2.getRotationMatrix2D(center, median_angle, 1.0)
                out_w, out_h = w0, h0
                if prevent_cropping:
                    # Enlarge the output canvas to the rotated bounding box and
                    # re-center the transform so no content is clipped.
                    cos = np.abs(M[0, 0])
                    sin = np.abs(M[0, 1])
                    out_w = int((h0 * sin) + (w0 * cos))
                    out_h = int((h0 * cos) + (w0 * sin))
                    M[0, 2] += (out_w / 2) - center[0]
                    M[1, 2] += (out_h / 2) - center[1]
                gray = cv2.warpAffine(
                    gray, M, (out_w, out_h),
                    flags=cv2.INTER_CUBIC,
                    borderMode=cv2.BORDER_REPLICATE,
                )

    return Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))


# ========================================
# OCR EXTRACTION
# ========================================
def _now_ms() -> int:
    """Current wall-clock time in integer milliseconds."""
    return int(datetime.datetime.now().timestamp() * 1000)


def extract_text_glm(
    image: Image.Image,
    prompt: str,
    max_tokens: int,
    processor,
    model,
    device: str,
) -> Tuple[str, int]:
    """Run GLM-OCR on one image.

    Returns:
        (decoded_text, elapsed_ms).
    """
    start_ms = _now_ms()
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Drop token_type_ids (unused by this model) while moving tensors to device.
    inputs = {k: v.to(device) for k, v in inputs.items() if k != "token_type_ids"}

    with torch.no_grad():
        # Greedy decoding; temperature is not passed because it is ignored
        # (and warned about) when do_sample=False.
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=False,
        )

    # Decode only the newly generated tokens (skip the echoed prompt).
    output_text = processor.decode(
        generated_ids[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )

    del inputs, generated_ids
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return output_text, _now_ms() - start_ms


def extract_text_dots(
    image: Image.Image,
    prompt: str,
    max_tokens: int,
    processor,
    model,
    device: str,
) -> Tuple[str, int]:
    """
    dots.ocr transformers inference (matches their model-card approach):
      - apply_chat_template(tokenize=False)
      - process_vision_info(messages) to get image_inputs/video_inputs
      - processor(text=[...], images=..., videos=..., return_tensors="pt")
      - generate, then trim input tokens, then decode

    Returns:
        (decoded_text, elapsed_ms).
    """
    start_ms = _now_ms()

    # dots.ocr examples use {"type":"image","image": <path-or-PIL>}; PIL works
    # in practice with most processors.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    if process_vision_info is not None:
        image_inputs, video_inputs = process_vision_info(messages)
    else:
        # Fallback: no video, single image
        image_inputs, video_inputs = [image], None

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Some processors return a BatchEncoding with .to(...)
    inputs = inputs.to(device)

    # Some processors add keys that this model doesn't consume.
    for k in ("mm_token_type_ids", "token_type_ids"):
        if k in inputs:
            inputs.pop(k)

    with torch.no_grad():
        # Greedy decoding; no temperature under do_sample=False (see GLM path).
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=False,
        )

    # Trim the prompt tokens from each sequence before decoding.
    in_len = inputs["input_ids"].shape[1]
    trimmed = [generated_ids[i][in_len:] for i in range(generated_ids.shape[0])]

    output_text = processor.batch_decode(
        trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    del inputs, generated_ids, trimmed
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return output_text, _now_ms() - start_ms


def extract_text(
    selected_model: OCRModel,
    image: Image.Image,
    prompt: str,
    max_tokens: int,
) -> Tuple[str, int]:
    """Dispatch extraction to the selected model's backend."""
    processor, model, device = get_loaded_model(selected_model)
    if selected_model == OCRModel.GLM_OCR:
        return extract_text_glm(image, prompt, max_tokens, processor, model, device)
    return extract_text_dots(image, prompt, max_tokens, processor, model, device)


def _flatten_dict(d, max_depth=2, current_depth=0):
    """Limit nesting of a parsed-JSON dict to max_depth levels.

    Dicts nested deeper than max_depth are dropped; lists and scalars pass
    through unchanged. Non-dict input is returned as-is.
    """
    if current_depth >= max_depth:
        return {}
    if not isinstance(d, dict):
        return d
    flattened = {}
    for key, value in d.items():
        if isinstance(value, dict):
            if current_depth < max_depth - 1:
                flattened[key] = _flatten_dict(value, max_depth, current_depth + 1)
        else:
            # Lists and scalars are kept verbatim.
            flattened[key] = value
    return flattened


# ========================================
# STREAMLIT UI
# ========================================
st.set_page_config(
    page_title="Universal OCR Scanner",
    page_icon="πŸ”",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Initialize session state with defaults (only for keys not yet present).
for k, v in {
    "should_process": False,
    "has_results": False,
    "output_text": "",
    "processing_time": 0,
    "doc_type": DocumentType.GENERAL,
    "selected_model": OCRModel.GLM_OCR,
    "current_file": None,
}.items():
    if k not in st.session_state:
        st.session_state[k] = v

st.title("πŸ” Universal OCR Scanner")
st.markdown("Extract text and structured data from **any document** - receipts, IDs, invoices, forms, and more!")

with st.sidebar:
    st.header("🧠 Model")
    selected_model = st.radio(
        "Select OCR model:",
        options=list(OCRModel),
        format_func=lambda x: f"{MODEL_UI[x]['icon']} {MODEL_UI[x]['name']}",
        index=list(OCRModel).index(st.session_state.selected_model),
    )
    st.session_state.selected_model = selected_model

    st.header("πŸ“‹ Document Type")
    doc_type = st.radio(
        "Select document type:",
        options=list(DocumentType),
        format_func=lambda x: f"{DOCUMENT_TEMPLATES[x]['icon']} {DOCUMENT_TEMPLATES[x]['name']}",
        index=list(DocumentType).index(st.session_state.doc_type),
    )
    st.session_state.doc_type = doc_type
    st.info(DOCUMENT_TEMPLATES[doc_type]['description'])

    st.markdown("---")
    st.header("βš™οΈ Image Enhancement")
    with st.expander("🎨 Preprocessing Options", expanded=False):
        enhance_contrast = st.checkbox("Enhance Contrast", value=False, help="Improve visibility of faded text")
        denoise = st.checkbox("Reduce Noise", value=False, help="Remove image noise and artifacts")
        sharpen = st.checkbox("Sharpen Text", value=False, help="Make text edges crisper")
        auto_rotate = st.checkbox("Auto-Rotate", value=False, help="Automatically straighten tilted documents")
        prevent_cropping = st.checkbox("Prevent-Cropping", value=False, help="Prevent cropping when rotate")

    st.markdown("---")
    with st.expander("πŸ”§ Advanced Options", expanded=False):
        show_preprocessed = st.checkbox("Show Preprocessed Image", value=False)
        max_tokens = st.slider("Max Output Tokens", 512, 4096, 2048, 256, help="Increase for longer documents")
        custom_prompt = st.checkbox("Use Custom Prompt", value=False)

    st.markdown("---")
    st.caption("πŸ’‘ **Tips:**")
    st.caption("β€’ Use good lighting")
    st.caption("β€’ Avoid shadows")
    st.caption("β€’ Keep text horizontal")
    st.caption("β€’ Use high resolution images")

col1, col2 = st.columns([1, 1])

with col1:
    st.subheader("πŸ“€ Upload Document")
    upload_tab, camera_tab = st.tabs(["πŸ“ Upload File", "πŸ“Έ Take Photo"])
    image: Optional[Image.Image] = None

    with upload_tab:
        uploaded_file = st.file_uploader(
            "Choose an image...",
            type=["jpg", "jpeg", "png", "webp"],
            help="Supported formats: JPG, PNG, WEBP",
        )
        if uploaded_file is not None:
            image = Image.open(uploaded_file).convert("RGB")
            # Invalidate stale results only when a *different* file arrives.
            if st.session_state.current_file != uploaded_file.name:
                st.session_state.current_file = uploaded_file.name
                st.session_state.has_results = False

    with camera_tab:
        camera_picture = st.camera_input("Take a photo")
        if camera_picture is not None:
            image = Image.open(BytesIO(camera_picture.getvalue())).convert("RGB")
            # Invalidate results only when the captured photo changes; resetting
            # on every rerun would wipe freshly computed results.
            cam_sig = hash(camera_picture.getvalue())
            if st.session_state.get("current_camera_sig") != cam_sig:
                st.session_state.current_camera_sig = cam_sig
                st.session_state.has_results = False

    if image is not None:
        st.image(image, caption="Original Image", width="content")

with col2:
    st.subheader("πŸ“‹ Extraction Settings")
    if custom_prompt:
        prompt = st.text_area(
            "Custom Extraction Prompt:",
            value=DOCUMENT_TEMPLATES[doc_type]['prompt'],
            height=200,
            help="Customize how the OCR extracts data",
            key="custom_prompt_text",
        )
    else:
        prompt = DOCUMENT_TEMPLATES[doc_type]['prompt']
        st.code(prompt, language="text")

    if image is not None:
        if st.button("πŸš€ Extract Text", type="primary", width="content", key="extract_button"):
            st.session_state.should_process = True
    else:
        st.info("πŸ‘† Upload or capture an image to begin")

# Processing — triggered once per button press via the should_process flag.
if image is not None and st.session_state.get('should_process', False):
    st.session_state.should_process = False
    with st.spinner("πŸ”„ Processing document..."):
        try:
            if enhance_contrast or denoise or sharpen or auto_rotate or prevent_cropping:
                preprocessed_image = preprocess_image(
                    image,
                    enhance_contrast=enhance_contrast,
                    denoise=denoise,
                    sharpen=sharpen,
                    auto_rotate=auto_rotate,
                    prevent_cropping=prevent_cropping,
                )
            else:
                preprocessed_image = image

            if show_preprocessed and preprocessed_image != image:
                st.subheader("πŸ”§ Preprocessed Image")
                col_a, col_b = st.columns(2)
                with col_a:
                    st.image(image, caption="Original", width="content")
                with col_b:
                    st.image(preprocessed_image, caption="Enhanced", width="content")

            output_text, processing_time = extract_text(
                selected_model=st.session_state.selected_model,
                image=preprocessed_image,
                prompt=prompt,
                max_tokens=max_tokens,
            )

            st.session_state.output_text = output_text
            st.session_state.processing_time = processing_time
            st.session_state.preprocessed_image = preprocessed_image
            st.session_state.has_results = True
        except Exception as e:
            st.error(f"❌ Error during extraction: {str(e)}")
            import traceback
            with st.expander("Show Error Details"):
                st.code(traceback.format_exc())
            st.session_state.has_results = False

# Results
if st.session_state.get('has_results', False):
    output_text = st.session_state.output_text
    processing_time = st.session_state.processing_time
    preprocessed_image = st.session_state.get('preprocessed_image', image)

    st.success(f"βœ… Extraction complete! ({processing_time}ms)")

    is_json = False
    parsed_data = None
    if doc_type in [
        DocumentType.FULL_JSON_SCHEMA,
        DocumentType.SIMPLE_TITLE_JSON,
        DocumentType.LOCALIZED_TITLE_JSON,
        DocumentType.GROUNDED_TITLE_JSON,
    ]:
        try:
            # Strip markdown code fences the model may wrap around the JSON.
            clean_text = output_text
            if "```json" in clean_text:
                clean_text = clean_text.split("```json")[1].split("```")[0].strip()
            elif "```" in clean_text:
                clean_text = clean_text.split("```")[1].split("```")[0].strip()
            if len(clean_text) > 50000:
                st.warning("⚠️ Detected unusually large JSON output. Truncating...")
                clean_text = clean_text[:50000]
            parsed_data = json.loads(clean_text)
            parsed_data = _flatten_dict(parsed_data, max_depth=2)
            is_json = True
        except json.JSONDecodeError:
            is_json = False
        except Exception as e:
            st.warning(f"⚠️ JSON parsing issue: {str(e)}")
            is_json = False

    st.markdown("---")
    st.subheader("πŸ“„ Extracted Data")

    if is_json and parsed_data:
        col_display, col_download = st.columns([2, 1])

        with col_display:
            for key, value in parsed_data.items():
                if isinstance(value, dict):
                    st.markdown(f"**{key.replace('_', ' ').title()}:**")
                    for k, v in value.items():
                        st.text(f"  {k}: {v}")
                elif isinstance(value, list):
                    st.markdown(f"**{key.replace('_', ' ').title()}:**")
                    if value and isinstance(value[0], dict):
                        df = pd.DataFrame(value)
                        st.dataframe(df, width="content", hide_index=True)
                    else:
                        for item in value:
                            st.text(f"  β€’ {item}")
                else:
                    st.markdown(f"**{key.replace('_', ' ').title()}:** {value}")

        with col_download:
            st.subheader("πŸ’Ύ Downloads")
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

            json_str = json.dumps(parsed_data, ensure_ascii=False, indent=2)
            st.download_button(
                label="πŸ“„ JSON",
                data=json_str,
                file_name=f"{doc_type.value}_{timestamp}.json",
                mime="application/json",
                width="content",
            )

            # CSV export is best-effort: nested values are serialized as JSON
            # strings so everything fits into a single flat row.
            try:
                flat_data = {}
                for k, v in parsed_data.items():
                    if isinstance(v, (dict, list)):
                        flat_data[k] = json.dumps(v, ensure_ascii=False)
                    else:
                        flat_data[k] = v
                df = pd.DataFrame([flat_data])
                csv_buffer = StringIO()
                df.to_csv(csv_buffer, index=False, encoding='utf-8')
                st.download_button(
                    label="πŸ“Š CSV",
                    data=csv_buffer.getvalue(),
                    file_name=f"{doc_type.value}_{timestamp}.csv",
                    mime="text/csv",
                    width="content",
                )
            except Exception:
                pass

            st.download_button(
                label="πŸ“ TXT",
                data=output_text,
                file_name=f"{doc_type.value}_{timestamp}.txt",
                mime="text/plain",
                width="content",
            )

        with st.expander("πŸ” View Raw JSON"):
            st.json(parsed_data)
    else:
        st.text_area(
            "Extracted Text:",
            value=output_text,
            height=400,
            label_visibility="collapsed",
        )
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        st.download_button(
            label="πŸ’Ύ Download as TXT",
            data=output_text,
            file_name=f"extracted_text_{timestamp}.txt",
            mime="text/plain",
        )

# Footer
st.markdown("---")
col_footer1, col_footer2, col_footer3 = st.columns(3)
with col_footer1:
    m = st.session_state.selected_model
    st.caption(f"⚑ Powered by {MODEL_UI[m]['name']}")
with col_footer2:
    # Derive the device label directly — calling get_loaded_model() here would
    # force a full model download/load just to render a caption.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    st.caption(f"πŸ–₯️ Device: {device.upper()}")
with col_footer3:
    st.caption("🌟 Universal Document Scanner")