"""GLM-OCR Studio — a Gradio web UI for the ``zai-org/GLM-OCR`` model.

Tabs:
    1. Document Parsing        — text / formula / table recognition
    2. Information Extraction  — JSON-schema driven structured extraction
    3. Custom Prompt           — free-form prompting with an image
    4. About                   — model information

The model is loaded once at import time; all tab handlers share it.
"""

import glob
import json
import os
import time

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
MODEL_PATH = "zai-org/GLM-OCR"
EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples")
MAX_NEW_TOKENS = 8192

# ---------------------------------------------------------------------------
# Model loading (once at startup)
# ---------------------------------------------------------------------------
print("=" * 60)
print(" Loading GLM-OCR model...")
print("=" * 60)

processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModelForImageTextToText.from_pretrained(
    pretrained_model_name_or_path=MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
model.eval()
print("✅ GLM-OCR model loaded successfully!\n")

# ---------------------------------------------------------------------------
# Prompt templates
# ---------------------------------------------------------------------------
# Task name shown in the UI -> literal prompt sent to the model.
TASK_PROMPTS = {
    "Text Recognition": "Text Recognition:",
    "Formula Recognition": "Formula Recognition:",
    "Table Recognition": "Table Recognition:",
}

# Pre-built extraction schemas for the demo. The JSON text is embedded
# verbatim into the extraction prompt, so the serialized form matters.
EXTRACTION_TEMPLATES = {
    "Custom (write your own)": "",
    "ID Card (English)": json.dumps(
        {
            "id_number": "",
            "last_name": "",
            "first_name": "",
            "date_of_birth": "",
            "address": {
                "street": "",
                "city": "",
                "state": "",
                "zip_code": "",
            },
            "dates": {"issue_date": "", "expiration_date": ""},
            "sex": "",
        },
        indent=2,
    ),
    "Invoice": json.dumps(
        {
            "invoice_number": "",
            "date": "",
            "vendor": "",
            "items": [{"description": "", "quantity": "", "unit_price": "", "amount": ""}],
            "subtotal": "",
            "tax": "",
            "total": "",
        },
        indent=2,
    ),
    "Business Card": json.dumps(
        {
            "name": "",
            "title": "",
            "company": "",
            "phone": "",
            "email": "",
            "address": "",
            "website": "",
        },
        indent=2,
    ),
    "Receipt": json.dumps(
        {
            "store_name": "",
            "date": "",
            "items": [{"name": "", "price": ""}],
            "subtotal": "",
            "tax": "",
            "total": "",
            "payment_method": "",
        },
        indent=2,
    ),
}

# ---------------------------------------------------------------------------
# Inference helpers
# ---------------------------------------------------------------------------
def run_ocr(image: Image.Image, prompt_text: str) -> tuple[str, float]:
    """Run GLM-OCR inference and return ``(output_text, elapsed_seconds)``.

    Args:
        image: The document image to analyze.
        prompt_text: The instruction sent alongside the image.

    Raises:
        gr.Error: If no image was provided (surfaced as a UI toast).
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt_text},
            ],
        }
    ]
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    # generate() does not accept token_type_ids for this model family.
    inputs.pop("token_type_ids", None)

    t0 = time.perf_counter()
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    elapsed = time.perf_counter() - t0

    # Decode only the newly generated portion (skip the echoed prompt).
    output_text = processor.decode(
        generated_ids[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=False,
    )
    # Strip the end-of-sequence tokens for cleaner output.
    # (Fixed: the original list contained an empty string, which made
    # one `str.replace` call a pointless no-op.)
    for tok in ["<|endoftext|>", "<|im_end|>", "<|end|>"]:
        output_text = output_text.replace(tok, "")
    return output_text.strip(), elapsed

# ---------------------------------------------------------------------------
# Tab handlers
# ---------------------------------------------------------------------------
def handle_document_parsing(image, task_name):
    """Handle document parsing tasks (text, formula, table)."""
    prompt = TASK_PROMPTS[task_name]
    result, elapsed = run_ocr(image, prompt)
    stats = f"⏱ {elapsed:.2f}s"
    return result, stats


def handle_info_extraction(image, template_name, custom_schema):
    """Handle information extraction with a JSON schema prompt."""
    if template_name == "Custom (write your own)":
        schema_text = custom_schema
    else:
        schema_text = EXTRACTION_TEMPLATES[template_name]
    if not schema_text or schema_text.strip() == "":
        raise gr.Error("Please provide a JSON schema for extraction.")
    # The model was trained with this Chinese instruction for structured
    # extraction; keep it verbatim.
    prompt = f"请按下列JSON格式输出图中信息:\n{schema_text}"
    result, elapsed = run_ocr(image, prompt)
    stats = f"⏱ {elapsed:.2f}s"
    return result, stats


def handle_custom_prompt(image, custom_prompt):
    """Handle a completely custom prompt entered by the user."""
    if not custom_prompt or custom_prompt.strip() == "":
        raise gr.Error("Please enter a prompt.")
    result, elapsed = run_ocr(image, custom_prompt.strip())
    stats = f"⏱ {elapsed:.2f}s"
    return result, stats

# ---------------------------------------------------------------------------
# UI helpers
# ---------------------------------------------------------------------------
def toggle_custom_schema(template_name):
    """Show/hide the custom schema textbox based on template selection."""
    return gr.update(visible=(template_name == "Custom (write your own)"))


def get_example_images():
    """Return a sorted list of example image paths (empty if none exist)."""
    if not os.path.isdir(EXAMPLES_DIR):
        return []
    exts = ["*.png", "*.jpg", "*.jpeg", "*.webp", "*.bmp"]
    paths = []
    for ext in exts:
        paths.extend(glob.glob(os.path.join(EXAMPLES_DIR, ext)))
    return sorted(paths)

# ---------------------------------------------------------------------------
# Custom CSS for a polished, unique look
# ---------------------------------------------------------------------------
CUSTOM_CSS = """
/* Global */
.gradio-container { max-width: 1280px !important; margin: auto; }

/* Header banner */
#header-banner {
    background: linear-gradient(135deg, #0f172a 0%, #1e3a5f 50%, #0ea5e9 100%);
    border-radius: 16px;
    padding: 28px 36px;
    margin-bottom: 16px;
    color: white;
    text-align: center;
}
#header-banner h1 {
    font-size: 2.2rem;
    font-weight: 800;
    margin: 0 0 4px 0;
    letter-spacing: -0.5px;
}
#header-banner p { margin: 4px 0 0 0; opacity: 0.85; font-size: 1rem; }

/* Stat badges */
.stat-badge {
    display: inline-block;
    background: rgba(14, 165, 233, 0.15);
    border: 1px solid rgba(14, 165, 233, 0.3);
    border-radius: 8px;
    padding: 4px 14px;
    font-size: 0.92rem;
    color: #0ea5e9;
    font-weight: 600;
}

/* Task cards */
.task-card {
    border: 1px solid #e2e8f0;
    border-radius: 12px;
    padding: 16px;
    transition: box-shadow 0.2s;
}
.task-card:hover { box-shadow: 0 4px 16px rgba(14,165,233,0.10); }

/* Run button */
#run-btn, #run-btn-extract, #run-btn-custom {
    background: linear-gradient(135deg, #0ea5e9, #2563eb) !important;
    color: white !important;
    font-weight: 700 !important;
    font-size: 1.05rem !important;
    border-radius: 10px !important;
    padding: 12px 0px !important;
    border: none !important;
    transition: transform 0.15s, box-shadow 0.15s !important;
}
#run-btn:hover, #run-btn-extract:hover, #run-btn-custom:hover {
    transform: translateY(-1px) !important;
    box-shadow: 0 6px 20px rgba(14,165,233,0.25) !important;
}

/* Output text */
#output-text textarea, #output-text-extract textarea, #output-text-custom textarea {
    font-family: 'Cascadia Code', 'Fira Code', 'Consolas', monospace !important;
    font-size: 0.92rem !important;
    line-height: 1.6 !important;
}

/* Stats label */
.stats-label { font-size: 0.9rem; color: #64748b; font-weight: 500; }

/* Footer */
#footer-info {
    text-align: center;
    padding: 12px;
    color: #94a3b8;
    font-size: 0.85rem;
}
"""

# ---------------------------------------------------------------------------
# Build the Gradio app
# ---------------------------------------------------------------------------
with gr.Blocks(
    title="GLM-OCR Studio",
    theme=gr.themes.Base(
        primary_hue=gr.themes.colors.sky,
        secondary_hue=gr.themes.colors.blue,
        neutral_hue=gr.themes.colors.slate,
        font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"],
        font_mono=[gr.themes.GoogleFont("Fira Code"), "Consolas", "monospace"],
    ),
    css=CUSTOM_CSS,
) as demo:
    # ---- Header ----
    # NOTE(review): the banner markup was garbled in the source; it has been
    # reconstructed to match the selectors defined in CUSTOM_CSS
    # (#header-banner, .stat-badge).
    gr.HTML(
        """
        <div id="header-banner">
            <h1>GLM-OCR Studio</h1>
            <p>Powered by <b>zai-org/GLM-OCR</b> — 0.9B parameters,
               state-of-the-art document understanding</p>
            <p>
                <span class="stat-badge">📝 Text Recognition</span>
                <span class="stat-badge">📐 Formula Recognition</span>
                <span class="stat-badge">📊 Table Recognition</span>
                <span class="stat-badge">🔍 Information Extraction</span>
                <span class="stat-badge">💬 Custom Prompts</span>
            </p>
        </div>
        """,
    )

    # ---- Main Tabs ----
    with gr.Tabs() as main_tabs:
        # ============================================================
        # TAB 1: Document Parsing
        # ============================================================
        with gr.Tab("📄 Document Parsing", id="tab-parse"):
            gr.Markdown(
                "Upload an image and choose a parsing mode. The model will extract "
                "text, formulas, or tables from your document."
            )
            with gr.Row(equal_height=True):
                # -- Left column: inputs --
                with gr.Column(scale=2, min_width=320):
                    parse_image = gr.Image(
                        type="pil",
                        label="Upload Document Image",
                        sources=["upload", "clipboard"],
                        height=380,
                    )
                    parse_task = gr.Radio(
                        choices=list(TASK_PROMPTS.keys()),
                        value="Text Recognition",
                        label="Parsing Mode",
                        info="Select what to extract from the image",
                    )
                    parse_btn = gr.Button(
                        "🚀 Run OCR",
                        variant="primary",
                        elem_id="run-btn",
                        size="lg",
                    )
                # -- Right column: outputs --
                with gr.Column(scale=3, min_width=400):
                    parse_stats = gr.Markdown("", elem_classes=["stats-label"])
                    parse_output = gr.Textbox(
                        label="Recognition Result",
                        lines=18,
                        max_lines=40,
                        elem_id="output-text",
                        interactive=False,
                    )

            parse_btn.click(
                fn=handle_document_parsing,
                inputs=[parse_image, parse_task],
                outputs=[parse_output, parse_stats],
            )

            # -- Examples --
            example_paths = get_example_images()
            if example_paths:
                gr.Markdown("### 📸 Try an Example")
                gr.Examples(
                    examples=[
                        [p, "Text Recognition"] for p in example_paths[:3]
                    ] + [
                        [p, "Table Recognition"]
                        for p in example_paths
                        if "table" in os.path.basename(p).lower()
                    ] + [
                        [p, "Formula Recognition"]
                        for p in example_paths
                        if "formula" in os.path.basename(p).lower()
                    ],
                    inputs=[parse_image, parse_task],
                    label="Click an example to load it",
                    cache_examples=False,
                )

        # ============================================================
        # TAB 2: Information Extraction
        # ============================================================
        with gr.Tab("🔍 Information Extraction", id="tab-extract"):
            gr.Markdown(
                "Extract structured data from documents using a JSON schema. "
                "Choose a pre-built template or write your own schema."
            )
            with gr.Row(equal_height=True):
                with gr.Column(scale=2, min_width=320):
                    extract_image = gr.Image(
                        type="pil",
                        label="Upload Document Image",
                        sources=["upload", "clipboard"],
                        height=300,
                    )
                    extract_template = gr.Dropdown(
                        choices=list(EXTRACTION_TEMPLATES.keys()),
                        value="Receipt",
                        label="Extraction Template",
                        info="Pre-built JSON schemas for common document types",
                    )
                    # Only visible when "Custom (write your own)" is selected.
                    extract_custom_schema = gr.Code(
                        label="Custom JSON Schema",
                        language="json",
                        lines=10,
                        visible=False,
                        value='{\n  "field_1": "",\n  "field_2": ""\n}',
                    )
                    extract_btn = gr.Button(
                        "🔍 Extract Information",
                        variant="primary",
                        elem_id="run-btn-extract",
                        size="lg",
                    )
                with gr.Column(scale=3, min_width=400):
                    extract_stats = gr.Markdown("", elem_classes=["stats-label"])
                    extract_output = gr.Textbox(
                        label="Extraction Result",
                        lines=18,
                        max_lines=40,
                        elem_id="output-text-extract",
                        interactive=False,
                    )

            extract_template.change(
                fn=toggle_custom_schema,
                inputs=extract_template,
                outputs=extract_custom_schema,
            )
            extract_btn.click(
                fn=handle_info_extraction,
                inputs=[extract_image, extract_template, extract_custom_schema],
                outputs=[extract_output, extract_stats],
            )

            if example_paths:
                gr.Markdown("### 📸 Try an Example")
                gr.Examples(
                    examples=[
                        [p, "Receipt", ""]
                        for p in example_paths
                        if "receipt" in os.path.basename(p).lower()
                    ] + [
                        [p, "Custom (write your own)", ""]
                        for p in example_paths[:1]
                    ],
                    inputs=[extract_image, extract_template, extract_custom_schema],
                    label="Click an example to load it",
                    cache_examples=False,
                )

        # ============================================================
        # TAB 3: Custom Prompt
        # ============================================================
        with gr.Tab("💬 Custom Prompt", id="tab-custom"):
            gr.Markdown(
                "Send any custom prompt to the model along with an image. "
                "Great for experimenting with different instructions."
            )
            with gr.Row(equal_height=True):
                with gr.Column(scale=2, min_width=320):
                    custom_image = gr.Image(
                        type="pil",
                        label="Upload Image",
                        sources=["upload", "clipboard"],
                        height=300,
                    )
                    custom_prompt = gr.Textbox(
                        label="Your Prompt",
                        placeholder="e.g., Describe the contents of this image in detail...",
                        lines=4,
                    )
                    with gr.Accordion("💡 Prompt Ideas", open=False):
                        gr.Markdown(
                            """
**Document Parsing prompts:**
- `Text Recognition:` — extract all text
- `Formula Recognition:` — extract LaTeX formulas
- `Table Recognition:` — parse tables

**Information Extraction (use JSON schema):**
```
请按下列JSON格式输出图中信息:
{ "name": "", "date": "", "total": "" }
```

**Tips:**
- Be specific about what you want extracted
- For structured output, provide a JSON template
- The model works best with clear, direct instructions
"""
                        )
                    custom_btn = gr.Button(
                        "▶ Run",
                        variant="primary",
                        elem_id="run-btn-custom",
                        size="lg",
                    )
                with gr.Column(scale=3, min_width=400):
                    custom_stats = gr.Markdown("", elem_classes=["stats-label"])
                    custom_output = gr.Textbox(
                        label="Model Output",
                        lines=18,
                        max_lines=40,
                        elem_id="output-text-custom",
                        interactive=False,
                    )

            custom_btn.click(
                fn=handle_custom_prompt,
                inputs=[custom_image, custom_prompt],
                outputs=[custom_output, custom_stats],
            )

        # ============================================================
        # TAB 4: About
        # ============================================================
        with gr.Tab("ℹ️ About", id="tab-about"):
            gr.Markdown(
                """
## About GLM-OCR

**GLM-OCR** is a multimodal OCR model for complex document understanding,
built on the GLM-V encoder–decoder architecture. It combines:

- **CogViT** visual encoder pre-trained on large-scale image–text data
- A lightweight cross-modal connector with efficient token downsampling
- A **GLM-0.5B** language decoder
- Multi-Token Prediction (MTP) loss and stable full-task reinforcement learning

### Key Features

| Feature | Description |
|---------|-------------|
| **#1 on OmniDocBench V1.5** | Score of 94.62, state-of-the-art across document understanding benchmarks |
| **0.9B Parameters** | Efficient inference, ideal for production and edge deployment |
| **Multi-format** | Handles text, formulas, tables, code, seals, and complex layouts |
| **Multiple backends** | Supports vLLM, SGLang, Ollama, and Transformers |

### Supported Tasks

1. **Text Recognition** — Extract raw text from documents and images
2. **Formula Recognition** — Convert mathematical formulas to LaTeX
3. **Table Recognition** — Parse tables into structured HTML/Markdown
4. **Information Extraction** — Extract structured JSON from documents using custom schemas

### Links

- 🏠 [Model Card on Hugging Face](https://huggingface.co/zai-org/GLM-OCR)
- 📦 [Official SDK on GitHub](https://github.com/zai-org/GLM-OCR)
- 📄 License: **MIT**

---
*Built with [Gradio](https://gradio.app) and [Transformers](https://huggingface.co/docs/transformers).*
"""
            )

    # ---- Footer ----
    # NOTE(review): the footer string was truncated/garbled in the source;
    # reconstructed against the #footer-info style defined in CUSTOM_CSS.
    gr.HTML(
        '<div id="footer-info">GLM-OCR Studio — '
        'powered by <b>zai-org/GLM-OCR</b> · built with Gradio</div>'
    )

# ---------------------------------------------------------------------------
# Launch
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Auto-download example images if the examples dir doesn't exist.
    # Best-effort: a missing/broken downloader should not prevent launch.
    if not os.path.isdir(EXAMPLES_DIR) or len(os.listdir(EXAMPLES_DIR)) == 0:
        print("📥 Downloading example images...")
        try:
            from download_examples import download_examples
            download_examples()
        except Exception as exc:  # noqa: BLE001 — demo must still start
            print(f"⚠️  Could not download example images: {exc}")

    demo.queue(max_size=10).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )