Spaces:
Running
Running
| import os | |
| import glob | |
| import json | |
| import time | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| from transformers import AutoProcessor, AutoModelForImageTextToText | |
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Hugging Face Hub repository id of the OCR checkpoint (processor + model).
MODEL_PATH = "zai-org/GLM-OCR"
# Demo images are looked up in an "examples" folder next to this script.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
EXAMPLES_DIR = os.path.join(_APP_DIR, "examples")
# Upper bound on the number of tokens produced by one generate() call.
MAX_NEW_TOKENS = 8192
# ---------------------------------------------------------------------------
# Model loading (once at startup)
# ---------------------------------------------------------------------------
_BANNER_RULE = "=" * 60
print(_BANNER_RULE)
print(" Loading GLM-OCR model...")
print(_BANNER_RULE)

# Processor and model come from the same Hub repository. trust_remote_code
# allows any custom code shipped in that repository to be executed.
processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
_model_load_kwargs = dict(
    pretrained_model_name_or_path=MODEL_PATH,
    torch_dtype=torch.bfloat16,  # load the weights in bfloat16
    device_map="auto",           # let Transformers decide device placement
    trust_remote_code=True,
)
model = AutoModelForImageTextToText.from_pretrained(**_model_load_kwargs)
model.eval()  # inference only (disables training-time layers such as dropout)
print("✅ GLM-OCR model loaded successfully!\n")
# ---------------------------------------------------------------------------
# Prompt templates
# ---------------------------------------------------------------------------
# Document-parsing modes offered in the UI. The prompt sent to the model for
# each mode is the mode's display name followed by a colon, e.g.
# "Text Recognition:". Insertion order is the order shown in the radio group.
TASK_PROMPTS = {
    mode: f"{mode}:"
    for mode in ("Text Recognition", "Formula Recognition", "Table Recognition")
}
# Pre-built extraction schemas for the demo.
# Each schema is a JSON object with empty values whose keys name the fields
# to extract. It is serialised with 2-space indentation and embedded verbatim
# after the extraction instruction. The "Custom" entry is intentionally empty:
# in that case the user supplies the schema text.
_ID_CARD_SCHEMA = {
    "id_number": "",
    "last_name": "",
    "first_name": "",
    "date_of_birth": "",
    "address": {
        "street": "",
        "city": "",
        "state": "",
        "zip_code": "",
    },
    "dates": {"issue_date": "", "expiration_date": ""},
    "sex": "",
}
_INVOICE_SCHEMA = {
    "invoice_number": "",
    "date": "",
    "vendor": "",
    "items": [{"description": "", "quantity": "", "unit_price": "", "amount": ""}],
    "subtotal": "",
    "tax": "",
    "total": "",
}
_BUSINESS_CARD_SCHEMA = {
    "name": "",
    "title": "",
    "company": "",
    "phone": "",
    "email": "",
    "address": "",
    "website": "",
}
_RECEIPT_SCHEMA = {
    "store_name": "",
    "date": "",
    "items": [{"name": "", "price": ""}],
    "subtotal": "",
    "tax": "",
    "total": "",
    "payment_method": "",
}

# Dropdown label -> schema text; insertion order is the dropdown order.
EXTRACTION_TEMPLATES = {"Custom (write your own)": ""}
EXTRACTION_TEMPLATES.update(
    (label, json.dumps(schema, indent=2))
    for label, schema in (
        ("ID Card (English)", _ID_CARD_SCHEMA),
        ("Invoice", _INVOICE_SCHEMA),
        ("Business Card", _BUSINESS_CARD_SCHEMA),
        ("Receipt", _RECEIPT_SCHEMA),
    )
)
# ---------------------------------------------------------------------------
# Inference helpers
# ---------------------------------------------------------------------------
# Markers that can survive decoding because special tokens are kept
# (skip_special_tokens=False). They are removed in this order.
_EOS_MARKERS = ("<|endoftext|>", "</s>", "<|im_end|>", "<|end|>")


def run_ocr(image: Image.Image, prompt_text: str) -> tuple[str, float]:
    """Run a single GLM-OCR generation on *image* with *prompt_text*.

    Args:
        image: Document image (PIL) placed first in the single user turn.
        prompt_text: Instruction placed after the image in that same turn.

    Returns:
        ``(text, seconds)``: ``text`` is the decoded continuation with the
        markers in ``_EOS_MARKERS`` removed and surrounding whitespace
        stripped; ``seconds`` is the wall-clock time spent inside
        ``model.generate`` only (pre-processing and decoding excluded).

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt_text},
        ],
    }
    model_inputs = processor.apply_chat_template(
        [user_turn],
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    # Dropped before generate(), presumably because the model does not accept
    # token_type_ids -- TODO confirm against the model's forward signature.
    model_inputs.pop("token_type_ids", None)
    prompt_length = model_inputs["input_ids"].shape[1]

    started = time.perf_counter()
    with torch.no_grad():
        generated_ids = model.generate(**model_inputs, max_new_tokens=MAX_NEW_TOKENS)
    elapsed = time.perf_counter() - started

    # Decode only the newly generated tokens (everything after the prompt).
    decoded = processor.decode(
        generated_ids[0][prompt_length:],
        skip_special_tokens=False,
    )
    for marker in _EOS_MARKERS:
        decoded = decoded.replace(marker, "")
    return decoded.strip(), elapsed
# ---------------------------------------------------------------------------
# Tab handlers
# ---------------------------------------------------------------------------
def handle_document_parsing(image, task_name):
    """Run one of the fixed document-parsing tasks on *image*.

    *task_name* is looked up in ``TASK_PROMPTS`` (the parsing-mode radio only
    offers those keys). Returns ``(result_text, timing_markdown)`` for the
    result textbox and the stats label.
    """
    text, seconds = run_ocr(image, TASK_PROMPTS[task_name])
    return text, f"⏱ {seconds:.2f}s"
def handle_info_extraction(image, template_name, custom_schema):
    """Extract structured fields from *image* following a JSON schema.

    Args:
        image: Document image forwarded to ``run_ocr``.
        template_name: Key of ``EXTRACTION_TEMPLATES``. The special entry
            ``"Custom (write your own)"`` means *custom_schema* is used
            instead of a pre-built template.
        custom_schema: Schema text typed by the user; only read when the
            custom template is selected. May be ``None``.

    Returns:
        ``(model_output, timing_markdown)`` for the result textbox and the
        stats label.

    Raises:
        gr.Error: If the selected schema is empty/whitespace, or if it is not
            valid JSON. Without this check a malformed custom schema would be
            embedded in the prompt as-is and yield meaningless output.
    """
    if template_name == "Custom (write your own)":
        schema_text = custom_schema
    else:
        schema_text = EXTRACTION_TEMPLATES[template_name]

    schema_text = (schema_text or "").strip()
    if not schema_text:
        raise gr.Error("Please provide a JSON schema for extraction.")

    try:
        json.loads(schema_text)
    except json.JSONDecodeError as err:
        raise gr.Error(
            f"The JSON schema is not valid JSON "
            f"(line {err.lineno}, column {err.colno}): {err.msg}"
        ) from err

    # Chinese instruction, roughly: "Please output the information in the
    # image in the following JSON format:" followed by the schema.
    prompt = f"请按下列JSON格式输出图中信息:\n{schema_text}"
    result, elapsed = run_ocr(image, prompt)
    return result, f"⏱ {elapsed:.2f}s"
def handle_custom_prompt(image, custom_prompt):
    """Send the user's own instruction to the model.

    The prompt is stripped of surrounding whitespace before use. Returns
    ``(model_output, timing_markdown)``.

    Raises:
        gr.Error: If the prompt is missing or whitespace-only.
    """
    cleaned = (custom_prompt or "").strip()
    if not cleaned:
        raise gr.Error("Please enter a prompt.")
    output, seconds = run_ocr(image, cleaned)
    return output, f"⏱ {seconds:.2f}s"
# ---------------------------------------------------------------------------
# UI helpers
# ---------------------------------------------------------------------------
def toggle_custom_schema(template_name):
    """Return a Gradio update that shows the custom-schema editor only while
    the "Custom (write your own)" template is selected, hiding it otherwise."""
    is_custom = template_name == "Custom (write your own)"
    return gr.update(visible=is_custom)
# File extensions (lower-case) accepted as example images.
_EXAMPLE_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp", ".bmp")


def get_example_images(directory=None):
    """Return the sorted paths of example images in *directory*.

    Args:
        directory: Folder to scan (non-recursive). Defaults to
            ``EXAMPLES_DIR``.

    Returns:
        Sorted list of ``os.path.join(directory, name)`` for every regular
        file (symlinks followed) whose extension is one of
        png/jpg/jpeg/webp/bmp. Extensions are matched case-insensitively, so
        ``SCAN.JPG`` is found on case-sensitive filesystems too. Hidden files
        (leading ``.``) are skipped, mirroring glob's ``*`` behaviour.
        Returns ``[]`` if the folder does not exist.
    """
    if directory is None:
        directory = EXAMPLES_DIR
    if not os.path.isdir(directory):
        return []
    # scandir instead of glob: no per-extension case sensitivity, and glob
    # metacharacters (e.g. "[") in the directory path cannot break matching.
    with os.scandir(directory) as entries:
        paths = [
            os.path.join(directory, entry.name)
            for entry in entries
            if not entry.name.startswith(".")
            and entry.name.lower().endswith(_EXAMPLE_IMAGE_EXTENSIONS)
            and entry.is_file()
        ]
    return sorted(paths)
# ---------------------------------------------------------------------------
# Custom CSS for a polished, unique look
# ---------------------------------------------------------------------------
# Passed to gr.Blocks(css=...). The #id selectors target components created
# with matching elem_id values:
# - header banner and footer HTML;
# - run buttons (run-btn, run-btn-extract, run-btn-custom);
# - output textboxes (output-text*).
# .stat-badge is used by the header HTML, and .stats-label by the timing
# Markdown components.
# NOTE(review): .task-card does not appear to be applied to any component
# in this file -- possibly dead CSS; confirm before removing.
CUSTOM_CSS = """
/* Global */
.gradio-container {
max-width: 1280px !important;
margin: auto;
}
/* Header banner */
#header-banner {
background: linear-gradient(135deg, #0f172a 0%, #1e3a5f 50%, #0ea5e9 100%);
border-radius: 16px;
padding: 28px 36px;
margin-bottom: 16px;
color: white;
text-align: center;
}
#header-banner h1 {
font-size: 2.2rem;
font-weight: 800;
margin: 0 0 4px 0;
letter-spacing: -0.5px;
}
#header-banner p {
margin: 4px 0 0 0;
opacity: 0.85;
font-size: 1rem;
}
/* Stat badges */
.stat-badge {
display: inline-block;
background: rgba(14, 165, 233, 0.15);
border: 1px solid rgba(14, 165, 233, 0.3);
border-radius: 8px;
padding: 4px 14px;
font-size: 0.92rem;
color: #0ea5e9;
font-weight: 600;
}
/* Task cards */
.task-card {
border: 1px solid #e2e8f0;
border-radius: 12px;
padding: 16px;
transition: box-shadow 0.2s;
}
.task-card:hover {
box-shadow: 0 4px 16px rgba(14,165,233,0.10);
}
/* Run button */
#run-btn, #run-btn-extract, #run-btn-custom {
background: linear-gradient(135deg, #0ea5e9, #2563eb) !important;
color: white !important;
font-weight: 700 !important;
font-size: 1.05rem !important;
border-radius: 10px !important;
padding: 12px 0px !important;
border: none !important;
transition: transform 0.15s, box-shadow 0.15s !important;
}
#run-btn:hover, #run-btn-extract:hover, #run-btn-custom:hover {
transform: translateY(-1px) !important;
box-shadow: 0 6px 20px rgba(14,165,233,0.25) !important;
}
/* Output text */
#output-text textarea, #output-text-extract textarea, #output-text-custom textarea {
font-family: 'Cascadia Code', 'Fira Code', 'Consolas', monospace !important;
font-size: 0.92rem !important;
line-height: 1.6 !important;
}
/* Stats label */
.stats-label {
font-size: 0.9rem;
color: #64748b;
font-weight: 500;
}
/* Footer */
#footer-info {
text-align: center;
padding: 12px;
color: #94a3b8;
font-size: 0.85rem;
}
"""
# ---------------------------------------------------------------------------
# Build the Gradio app
# ---------------------------------------------------------------------------
# Initial contents of the custom-schema editor. The extraction examples reuse
# it so that loading an example never leaves the editor empty.
_DEFAULT_CUSTOM_SCHEMA = '{\n "field_1": "",\n "field_2": ""\n}'


def _ensure_example_images():
    """Download the demo images before the UI is built, if they are missing.

    The example galleries below are populated from ``example_paths`` while
    the UI is being built, so the images must already exist at this point.
    Fetching them only at launch time would leave the galleries empty on a
    fresh checkout until the next restart.

    Only acts when this file is run as a script. Best-effort: a failure is
    printed and the app is built without examples.
    """
    if __name__ != "__main__":
        return
    if os.path.isdir(EXAMPLES_DIR) and os.listdir(EXAMPLES_DIR):
        return
    print("📥 Downloading example images...")
    try:
        from download_examples import download_examples

        download_examples()
    except Exception as exc:  # examples are optional; keep the app usable
        print(f"⚠️ Could not download example images: {exc}")


def _parse_examples(paths, max_text_examples=3):
    """Build ``[image, mode]`` rows for the Document Parsing examples.

    Each image appears exactly once:
    - files whose name contains "table" use Table Recognition;
    - otherwise, files whose name contains "formula" use Formula Recognition;
    - of the remaining images, at most *max_text_examples* are offered with
      Text Recognition.
    """
    tables, formulas, others = [], [], []
    for path in paths:
        name = os.path.basename(path).lower()
        if "table" in name:
            tables.append(path)
        elif "formula" in name:
            formulas.append(path)
        else:
            others.append(path)
    return (
        [[p, "Text Recognition"] for p in others[:max_text_examples]]
        + [[p, "Table Recognition"] for p in tables]
        + [[p, "Formula Recognition"] for p in formulas]
    )


def _extract_examples(paths):
    """Build ``[image, template, custom_schema]`` rows for the extraction tab.

    Receipt images (by file name) use the Receipt template, and the first
    image is also offered with the custom template. Every row carries
    ``_DEFAULT_CUSTOM_SCHEMA`` rather than an empty string. Otherwise the
    custom example would open an empty editor and the Extract button would
    always fail with "Please provide a JSON schema".
    """
    rows = [
        [p, "Receipt", _DEFAULT_CUSTOM_SCHEMA]
        for p in paths
        if "receipt" in os.path.basename(p).lower()
    ]
    rows += [[p, "Custom (write your own)", _DEFAULT_CUSTOM_SCHEMA] for p in paths[:1]]
    return rows


_ensure_example_images()
# Scanned once, before the UI is built; both example galleries use this list.
example_paths = get_example_images()

with gr.Blocks(
    title="GLM-OCR Studio",
    theme=gr.themes.Base(
        primary_hue=gr.themes.colors.sky,
        secondary_hue=gr.themes.colors.blue,
        neutral_hue=gr.themes.colors.slate,
        font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"],
        font_mono=[gr.themes.GoogleFont("Fira Code"), "Consolas", "monospace"],
    ),
    css=CUSTOM_CSS,
) as demo:
    # ---- Header ----
    gr.HTML(
        """
<div id="header-banner">
<h1>GLM-OCR Studio</h1>
<p>Powered by <strong>zai-org/GLM-OCR</strong> — 0.9B parameters, state-of-the-art document understanding</p>
<div style="margin-top:12px; display:flex; gap:12px; justify-content:center; flex-wrap:wrap;">
<span class="stat-badge">📝 Text Recognition</span>
<span class="stat-badge">📐 Formula Recognition</span>
<span class="stat-badge">📊 Table Recognition</span>
<span class="stat-badge">🔍 Information Extraction</span>
<span class="stat-badge">💬 Custom Prompts</span>
</div>
</div>
""",
    )

    # ---- Main Tabs ----
    with gr.Tabs() as main_tabs:
        # ============================================================
        # TAB 1: Document Parsing
        # ============================================================
        with gr.Tab("📄 Document Parsing", id="tab-parse"):
            gr.Markdown(
                "Upload an image and choose a parsing mode. The model will extract "
                "text, formulas, or tables from your document."
            )
            with gr.Row(equal_height=True):
                # -- Left column: inputs --
                with gr.Column(scale=2, min_width=320):
                    parse_image = gr.Image(
                        type="pil",
                        label="Upload Document Image",
                        sources=["upload", "clipboard"],
                        height=380,
                    )
                    parse_task = gr.Radio(
                        choices=list(TASK_PROMPTS.keys()),
                        value="Text Recognition",
                        label="Parsing Mode",
                        info="Select what to extract from the image",
                    )
                    parse_btn = gr.Button(
                        "🚀 Run OCR",
                        variant="primary",
                        elem_id="run-btn",
                        size="lg",
                    )
                # -- Right column: outputs --
                with gr.Column(scale=3, min_width=400):
                    parse_stats = gr.Markdown("", elem_classes=["stats-label"])
                    parse_output = gr.Textbox(
                        label="Recognition Result",
                        lines=18,
                        max_lines=40,
                        elem_id="output-text",
                        interactive=False,
                    )
            parse_btn.click(
                fn=handle_document_parsing,
                inputs=[parse_image, parse_task],
                outputs=[parse_output, parse_stats],
            )
            # -- Examples --
            if example_paths:
                gr.Markdown("### 📸 Try an Example")
                gr.Examples(
                    examples=_parse_examples(example_paths),
                    inputs=[parse_image, parse_task],
                    label="Click an example to load it",
                    cache_examples=False,
                )

        # ============================================================
        # TAB 2: Information Extraction
        # ============================================================
        with gr.Tab("🔍 Information Extraction", id="tab-extract"):
            gr.Markdown(
                "Extract structured data from documents using a JSON schema. "
                "Choose a pre-built template or write your own schema."
            )
            with gr.Row(equal_height=True):
                with gr.Column(scale=2, min_width=320):
                    extract_image = gr.Image(
                        type="pil",
                        label="Upload Document Image",
                        sources=["upload", "clipboard"],
                        height=300,
                    )
                    extract_template = gr.Dropdown(
                        choices=list(EXTRACTION_TEMPLATES.keys()),
                        value="Receipt",
                        label="Extraction Template",
                        info="Pre-built JSON schemas for common document types",
                    )
                    # Hidden until the custom template is selected
                    # (see the .change handler below).
                    extract_custom_schema = gr.Code(
                        label="Custom JSON Schema",
                        language="json",
                        lines=10,
                        visible=False,
                        value=_DEFAULT_CUSTOM_SCHEMA,
                    )
                    extract_btn = gr.Button(
                        "🔍 Extract Information",
                        variant="primary",
                        elem_id="run-btn-extract",
                        size="lg",
                    )
                with gr.Column(scale=3, min_width=400):
                    extract_stats = gr.Markdown("", elem_classes=["stats-label"])
                    extract_output = gr.Textbox(
                        label="Extraction Result",
                        lines=18,
                        max_lines=40,
                        elem_id="output-text-extract",
                        interactive=False,
                    )
            extract_template.change(
                fn=toggle_custom_schema,
                inputs=extract_template,
                outputs=extract_custom_schema,
            )
            extract_btn.click(
                fn=handle_info_extraction,
                inputs=[extract_image, extract_template, extract_custom_schema],
                outputs=[extract_output, extract_stats],
            )
            if example_paths:
                gr.Markdown("### 📸 Try an Example")
                gr.Examples(
                    examples=_extract_examples(example_paths),
                    inputs=[extract_image, extract_template, extract_custom_schema],
                    label="Click an example to load it",
                    cache_examples=False,
                )

        # ============================================================
        # TAB 3: Custom Prompt
        # ============================================================
        with gr.Tab("💬 Custom Prompt", id="tab-custom"):
            gr.Markdown(
                "Send any custom prompt to the model along with an image. "
                "Great for experimenting with different instructions."
            )
            with gr.Row(equal_height=True):
                with gr.Column(scale=2, min_width=320):
                    custom_image = gr.Image(
                        type="pil",
                        label="Upload Image",
                        sources=["upload", "clipboard"],
                        height=300,
                    )
                    custom_prompt = gr.Textbox(
                        label="Your Prompt",
                        placeholder="e.g., Describe the contents of this image in detail...",
                        lines=4,
                    )
                    with gr.Accordion("💡 Prompt Ideas", open=False):
                        # Blank lines close each list. Without them Markdown
                        # renders the next bold heading as a lazy
                        # continuation of the previous bullet.
                        gr.Markdown(
                            """
**Document Parsing prompts:**
- `Text Recognition:` — extract all text
- `Formula Recognition:` — extract LaTeX formulas
- `Table Recognition:` — parse tables

**Information Extraction (use JSON schema):**
```
请按下列JSON格式输出图中信息:
{ "name": "", "date": "", "total": "" }
```

**Tips:**
- Be specific about what you want extracted
- For structured output, provide a JSON template
- The model works best with clear, direct instructions
"""
                        )
                    custom_btn = gr.Button(
                        "▶ Run",
                        variant="primary",
                        elem_id="run-btn-custom",
                        size="lg",
                    )
                with gr.Column(scale=3, min_width=400):
                    custom_stats = gr.Markdown("", elem_classes=["stats-label"])
                    custom_output = gr.Textbox(
                        label="Model Output",
                        lines=18,
                        max_lines=40,
                        elem_id="output-text-custom",
                        interactive=False,
                    )
            custom_btn.click(
                fn=handle_custom_prompt,
                inputs=[custom_image, custom_prompt],
                outputs=[custom_output, custom_stats],
            )

        # ============================================================
        # TAB 4: About
        # ============================================================
        with gr.Tab("ℹ️ About", id="tab-about"):
            gr.Markdown(
                """
## About GLM-OCR
**GLM-OCR** is a multimodal OCR model for complex document understanding,
built on the GLM-V encoder–decoder architecture. It combines:
- **CogViT** visual encoder pre-trained on large-scale image–text data
- A lightweight cross-modal connector with efficient token downsampling
- A **GLM-0.5B** language decoder
- Multi-Token Prediction (MTP) loss and stable full-task reinforcement learning
### Key Features
| Feature | Description |
|---------|-------------|
| **#1 on OmniDocBench V1.5** | Score of 94.62, state-of-the-art across document understanding benchmarks |
| **0.9B Parameters** | Efficient inference, ideal for production and edge deployment |
| **Multi-format** | Handles text, formulas, tables, code, seals, and complex layouts |
| **Multiple backends** | Supports vLLM, SGLang, Ollama, and Transformers |
### Supported Tasks
1. **Text Recognition** — Extract raw text from documents and images
2. **Formula Recognition** — Convert mathematical formulas to LaTeX
3. **Table Recognition** — Parse tables into structured HTML/Markdown
4. **Information Extraction** — Extract structured JSON from documents using custom schemas
### Links
- 🏠 [Model Card on Hugging Face](https://huggingface.co/zai-org/GLM-OCR)
- 📦 [Official SDK on GitHub](https://github.com/zai-org/GLM-OCR)
- 📄 License: **MIT**
---
*Built with [Gradio](https://gradio.app) and [Transformers](https://huggingface.co/docs/transformers).*
"""
            )

    # ---- Footer ----
    gr.HTML(
        '<div id="footer-info">'
        "GLM-OCR Studio • Model: zai-org/GLM-OCR • "
        "MIT License • Powered by Gradio"
        "</div>"
    )
# ---------------------------------------------------------------------------
# Launch
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Fetch the demo images when the examples folder is missing or empty.
    examples_missing = not os.path.isdir(EXAMPLES_DIR) or not os.listdir(EXAMPLES_DIR)
    if examples_missing:
        print("📥 Downloading example images...")
        from download_examples import download_examples

        download_examples()

    # Queue holds at most 10 pending events. The server listens on all
    # interfaces on port 7860, with no public share link, and errors are
    # surfaced in the UI.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo.queue(max_size=10).launch(**launch_options)