# GLM-OCR-Studio / app.py
# singhankit16's picture
# gradio_app.py
# 753f4eb verified
import os
import glob
import json
import time
import torch
import gradio as gr
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Identifier handed to `from_pretrained` for both the processor and the model.
MODEL_PATH = "zai-org/GLM-OCR"
# Example images are looked up in an "examples" folder next to this script.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
EXAMPLES_DIR = os.path.join(_APP_DIR, "examples")
# Upper bound on the number of tokens generated per request.
MAX_NEW_TOKENS = 8192
# ---------------------------------------------------------------------------
# Model loading (once at startup)
# ---------------------------------------------------------------------------
# Startup banner: rule, message, rule (three lines on stdout).
_banner_rule = "=" * 60
print("\n".join((_banner_rule, " Loading GLM-OCR model...", _banner_rule)))

# The processor and the model are created once at import time and shared by
# every request handler below.
processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModelForImageTextToText.from_pretrained(
    pretrained_model_name_or_path=MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
# Inference only: switch the module to evaluation mode.
model.eval()

print("✅ GLM-OCR model loaded successfully!\n")
# ---------------------------------------------------------------------------
# Prompt templates
# ---------------------------------------------------------------------------
# Each parsing mode shown in the UI maps to a prompt that is simply the mode
# name followed by a colon; insertion order is the order of the radio buttons.
TASK_PROMPTS = {
    task: f"{task}:"
    for task in ("Text Recognition", "Formula Recognition", "Table Recognition")
}
# Pre-built extraction schemas for the demo
# Blank schemas for the built-in extraction templates, keyed by the label
# shown in the template dropdown. Every leaf value is an empty string.
_TEMPLATE_SCHEMAS = {
    "ID Card (English)": {
        "id_number": "",
        "last_name": "",
        "first_name": "",
        "date_of_birth": "",
        "address": {
            "street": "",
            "city": "",
            "state": "",
            "zip_code": "",
        },
        "dates": {"issue_date": "", "expiration_date": ""},
        "sex": "",
    },
    "Invoice": {
        "invoice_number": "",
        "date": "",
        "vendor": "",
        "items": [{"description": "", "quantity": "", "unit_price": "", "amount": ""}],
        "subtotal": "",
        "tax": "",
        "total": "",
    },
    "Business Card": {
        "name": "",
        "title": "",
        "company": "",
        "phone": "",
        "email": "",
        "address": "",
        "website": "",
    },
    "Receipt": {
        "store_name": "",
        "date": "",
        "items": [{"name": "", "price": ""}],
        "subtotal": "",
        "tax": "",
        "total": "",
        "payment_method": "",
    },
}

# Dropdown label -> schema text (2-space indented JSON). The "Custom" entry is
# deliberately empty: it stands for "use whatever the user typed instead".
EXTRACTION_TEMPLATES = {
    "Custom (write your own)": "",
    **{
        label: json.dumps(schema, indent=2)
        for label, schema in _TEMPLATE_SCHEMAS.items()
    },
}
# ---------------------------------------------------------------------------
# Inference helpers
# ---------------------------------------------------------------------------
def run_ocr(image: Image.Image, prompt_text: str) -> tuple[str, float]:
    """Send one image plus one text prompt to the model and return its answer.

    Args:
        image: The picture to analyse; ``None`` is rejected.
        prompt_text: Instruction placed after the image in the user turn.

    Returns:
        ``(text, seconds)`` where ``text`` is the decoded completion with the
        known end-of-sequence markers removed and surrounding whitespace
        stripped, and ``seconds`` is the wall-clock time spent inside
        ``model.generate`` only (pre- and post-processing are not timed).

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # A single user turn: the image first, then the textual instruction.
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt_text},
        ],
    }
    model_inputs = processor.apply_chat_template(
        [user_turn],
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    # Dropped before generation when present — presumably `generate` does not
    # accept this key for this model; confirm before removing.
    model_inputs.pop("token_type_ids", None)

    started = time.perf_counter()
    with torch.no_grad():
        generated_ids = model.generate(**model_inputs, max_new_tokens=MAX_NEW_TOKENS)
    elapsed = time.perf_counter() - started

    # The generated sequence starts with the prompt; keep only what follows it.
    prompt_length = model_inputs["input_ids"].shape[1]
    completion = processor.decode(
        generated_ids[0][prompt_length:],
        skip_special_tokens=False,
    )

    # Special tokens are kept by the decoder above, so the end-of-sequence
    # markers are removed by hand for cleaner output.
    for marker in ("<|endoftext|>", "</s>", "<|im_end|>", "<|end|>"):
        completion = completion.replace(marker, "")

    return completion.strip(), elapsed
# ---------------------------------------------------------------------------
# Tab handlers
# ---------------------------------------------------------------------------
def handle_document_parsing(image, task_name):
    """Run the parsing mode selected in the first tab.

    ``task_name`` must be a key of ``TASK_PROMPTS``. Returns the recognised
    text together with a short timing string for the stats label.
    """
    recognized, seconds = run_ocr(image, TASK_PROMPTS[task_name])
    return recognized, f"⏱ {seconds:.2f}s"
def handle_info_extraction(image, template_name, custom_schema):
    """Extract structured data by prompting the model with a JSON schema.

    The schema is ``custom_schema`` when the "Custom" template is selected,
    otherwise the pre-built entry of ``EXTRACTION_TEMPLATES``. Returns the
    model output together with a short timing string.

    Raises:
        gr.Error: If the resulting schema is missing or only whitespace.
    """
    use_custom = template_name == "Custom (write your own)"
    schema_text = custom_schema if use_custom else EXTRACTION_TEMPLATES[template_name]

    if not (schema_text and schema_text.strip()):
        raise gr.Error("Please provide a JSON schema for extraction.")

    # The instruction is in Chinese; the schema follows it on the next line.
    extracted, seconds = run_ocr(image, f"请按下列JSON格式输出图中信息:\n{schema_text}")
    return extracted, f"⏱ {seconds:.2f}s"
def handle_custom_prompt(image, custom_prompt):
    """Forward a free-form user prompt (whitespace-trimmed) to the model.

    Returns the model output together with a short timing string.

    Raises:
        gr.Error: If the prompt is missing or only whitespace.
    """
    if not (custom_prompt and custom_prompt.strip()):
        raise gr.Error("Please enter a prompt.")

    answer, seconds = run_ocr(image, custom_prompt.strip())
    return answer, f"⏱ {seconds:.2f}s"
# ---------------------------------------------------------------------------
# UI helpers
# ---------------------------------------------------------------------------
def toggle_custom_schema(template_name):
    """Return an update that shows the schema editor only for the "Custom" template."""
    show_editor = template_name == "Custom (write your own)"
    return gr.update(visible=show_editor)
def get_example_images():
    """Return the sorted paths of the example images found in ``EXAMPLES_DIR``.

    A file counts as an example image when its name ends in ``.png``,
    ``.jpg``, ``.jpeg``, ``.webp`` or ``.bmp``. The extension is compared
    case-insensitively, so ``scan.JPG`` is picked up on every platform.
    Dotfiles and directories are ignored.

    Returns:
        A list of paths (``EXAMPLES_DIR`` joined with the file name) in sorted
        order, or an empty list when the directory does not exist.
    """
    if not os.path.isdir(EXAMPLES_DIR):
        return []

    image_suffixes = (".png", ".jpg", ".jpeg", ".webp", ".bmp")
    found = []
    for name in os.listdir(EXAMPLES_DIR):
        # Hidden files were never matched by the "*.ext" patterns this
        # function used to rely on; keep skipping them.
        if name.startswith(".") or not name.lower().endswith(image_suffixes):
            continue
        path = os.path.join(EXAMPLES_DIR, name)
        # A directory that merely looks like an image cannot be loaded.
        if os.path.isfile(path):
            found.append(path)
    return sorted(found)
# ---------------------------------------------------------------------------
# Custom CSS for a polished, unique look
# ---------------------------------------------------------------------------
# Stylesheet passed to the Blocks app through `css=CUSTOM_CSS`.
# The selectors line up with ids/classes assigned in the layout below:
#   #header-banner, .stat-badge      -> the header HTML block
#   #run-btn, #run-btn-extract,
#   #run-btn-custom                  -> the three "run" buttons (elem_id)
#   #output-text, #output-text-extract,
#   #output-text-custom              -> the three result textboxes (elem_id)
#   .stats-label                     -> the timing Markdown (elem_classes)
#   #footer-info                     -> the footer HTML block
# NOTE(review): `.task-card` is not attached to any component visible in this
# file — confirm whether it is still needed.
CUSTOM_CSS = """
/* Global */
.gradio-container {
max-width: 1280px !important;
margin: auto;
}
/* Header banner */
#header-banner {
background: linear-gradient(135deg, #0f172a 0%, #1e3a5f 50%, #0ea5e9 100%);
border-radius: 16px;
padding: 28px 36px;
margin-bottom: 16px;
color: white;
text-align: center;
}
#header-banner h1 {
font-size: 2.2rem;
font-weight: 800;
margin: 0 0 4px 0;
letter-spacing: -0.5px;
}
#header-banner p {
margin: 4px 0 0 0;
opacity: 0.85;
font-size: 1rem;
}
/* Stat badges */
.stat-badge {
display: inline-block;
background: rgba(14, 165, 233, 0.15);
border: 1px solid rgba(14, 165, 233, 0.3);
border-radius: 8px;
padding: 4px 14px;
font-size: 0.92rem;
color: #0ea5e9;
font-weight: 600;
}
/* Task cards */
.task-card {
border: 1px solid #e2e8f0;
border-radius: 12px;
padding: 16px;
transition: box-shadow 0.2s;
}
.task-card:hover {
box-shadow: 0 4px 16px rgba(14,165,233,0.10);
}
/* Run button */
#run-btn, #run-btn-extract, #run-btn-custom {
background: linear-gradient(135deg, #0ea5e9, #2563eb) !important;
color: white !important;
font-weight: 700 !important;
font-size: 1.05rem !important;
border-radius: 10px !important;
padding: 12px 0px !important;
border: none !important;
transition: transform 0.15s, box-shadow 0.15s !important;
}
#run-btn:hover, #run-btn-extract:hover, #run-btn-custom:hover {
transform: translateY(-1px) !important;
box-shadow: 0 6px 20px rgba(14,165,233,0.25) !important;
}
/* Output text */
#output-text textarea, #output-text-extract textarea, #output-text-custom textarea {
font-family: 'Cascadia Code', 'Fira Code', 'Consolas', monospace !important;
font-size: 0.92rem !important;
line-height: 1.6 !important;
}
/* Stats label */
.stats-label {
font-size: 0.9rem;
color: #64748b;
font-weight: 500;
}
/* Footer */
#footer-info {
text-align: center;
padding: 12px;
color: #94a3b8;
font-size: 0.85rem;
}
"""
# ---------------------------------------------------------------------------
# Build the Gradio app
# ---------------------------------------------------------------------------
# The example galleries below are decided while the layout is being built
# (`get_example_images()` is called inside the first tab). When the script is
# run directly and the examples folder is missing or empty, fetch the images
# *now*; downloading only right before `launch()` would be too late for them
# to appear in the UI that has already been assembled.
if __name__ == "__main__" and (
    not os.path.isdir(EXAMPLES_DIR) or not os.listdir(EXAMPLES_DIR)
):
    print("📥 Downloading example images...")
    from download_examples import download_examples

    download_examples()

# NOTE: the triple-quoted HTML/Markdown literals in this block are kept flush
# with the left margin on purpose, so the rendered text carries no leading
# whitespace (indentation is significant in Markdown).
with gr.Blocks(
    title="GLM-OCR Studio",
    theme=gr.themes.Base(
        primary_hue=gr.themes.colors.sky,
        secondary_hue=gr.themes.colors.blue,
        neutral_hue=gr.themes.colors.slate,
        font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"],
        font_mono=[gr.themes.GoogleFont("Fira Code"), "Consolas", "monospace"],
    ),
    css=CUSTOM_CSS,
) as demo:
    # ---- Header ----
    gr.HTML(
        """
<div id="header-banner">
<h1>GLM-OCR Studio</h1>
<p>Powered by <strong>zai-org/GLM-OCR</strong> — 0.9B parameters, state-of-the-art document understanding</p>
<div style="margin-top:12px; display:flex; gap:12px; justify-content:center; flex-wrap:wrap;">
<span class="stat-badge">📝 Text Recognition</span>
<span class="stat-badge">📐 Formula Recognition</span>
<span class="stat-badge">📊 Table Recognition</span>
<span class="stat-badge">🔍 Information Extraction</span>
<span class="stat-badge">💬 Custom Prompts</span>
</div>
</div>
""",
    )

    # ---- Main Tabs ----
    with gr.Tabs() as main_tabs:
        # ============================================================
        # TAB 1: Document Parsing
        # ============================================================
        with gr.Tab("📄 Document Parsing", id="tab-parse"):
            gr.Markdown(
                "Upload an image and choose a parsing mode. The model will extract "
                "text, formulas, or tables from your document."
            )
            with gr.Row(equal_height=True):
                # -- Left column: inputs --
                with gr.Column(scale=2, min_width=320):
                    parse_image = gr.Image(
                        type="pil",
                        label="Upload Document Image",
                        sources=["upload", "clipboard"],
                        height=380,
                    )
                    parse_task = gr.Radio(
                        choices=list(TASK_PROMPTS.keys()),
                        value="Text Recognition",
                        label="Parsing Mode",
                        info="Select what to extract from the image",
                    )
                    parse_btn = gr.Button(
                        "🚀 Run OCR",
                        variant="primary",
                        elem_id="run-btn",
                        size="lg",
                    )
                # -- Right column: outputs --
                with gr.Column(scale=3, min_width=400):
                    parse_stats = gr.Markdown("", elem_classes=["stats-label"])
                    parse_output = gr.Textbox(
                        label="Recognition Result",
                        lines=18,
                        max_lines=40,
                        elem_id="output-text",
                        interactive=False,
                    )
            parse_btn.click(
                fn=handle_document_parsing,
                inputs=[parse_image, parse_task],
                outputs=[parse_output, parse_stats],
            )
            # -- Examples --
            # Computed once here and reused by the extraction tab further down.
            example_paths = get_example_images()
            if example_paths:
                gr.Markdown("### 📸 Try an Example")
                gr.Examples(
                    # First three images as plain text examples, then every
                    # image whose file name hints at a table or a formula.
                    examples=[
                        [p, "Text Recognition"]
                        for p in example_paths[:3]
                    ] + [
                        [p, "Table Recognition"]
                        for p in example_paths
                        if "table" in os.path.basename(p).lower()
                    ] + [
                        [p, "Formula Recognition"]
                        for p in example_paths
                        if "formula" in os.path.basename(p).lower()
                    ],
                    inputs=[parse_image, parse_task],
                    label="Click an example to load it",
                    cache_examples=False,
                )

        # ============================================================
        # TAB 2: Information Extraction
        # ============================================================
        with gr.Tab("🔍 Information Extraction", id="tab-extract"):
            gr.Markdown(
                "Extract structured data from documents using a JSON schema. "
                "Choose a pre-built template or write your own schema."
            )
            with gr.Row(equal_height=True):
                with gr.Column(scale=2, min_width=320):
                    extract_image = gr.Image(
                        type="pil",
                        label="Upload Document Image",
                        sources=["upload", "clipboard"],
                        height=300,
                    )
                    extract_template = gr.Dropdown(
                        choices=list(EXTRACTION_TEMPLATES.keys()),
                        value="Receipt",
                        label="Extraction Template",
                        info="Pre-built JSON schemas for common document types",
                    )
                    # Hidden until the "Custom" template is selected
                    # (see the `change` handler wired below).
                    extract_custom_schema = gr.Code(
                        label="Custom JSON Schema",
                        language="json",
                        lines=10,
                        visible=False,
                        value='{\n "field_1": "",\n "field_2": ""\n}',
                    )
                    extract_btn = gr.Button(
                        "🔍 Extract Information",
                        variant="primary",
                        elem_id="run-btn-extract",
                        size="lg",
                    )
                with gr.Column(scale=3, min_width=400):
                    extract_stats = gr.Markdown("", elem_classes=["stats-label"])
                    extract_output = gr.Textbox(
                        label="Extraction Result",
                        lines=18,
                        max_lines=40,
                        elem_id="output-text-extract",
                        interactive=False,
                    )
            extract_template.change(
                fn=toggle_custom_schema,
                inputs=extract_template,
                outputs=extract_custom_schema,
            )
            extract_btn.click(
                fn=handle_info_extraction,
                inputs=[extract_image, extract_template, extract_custom_schema],
                outputs=[extract_output, extract_stats],
            )
            if example_paths:
                gr.Markdown("### 📸 Try an Example")
                gr.Examples(
                    examples=[
                        [p, "Receipt", ""]
                        for p in example_paths
                        if "receipt" in os.path.basename(p).lower()
                    ] + [
                        [p, "Custom (write your own)", ""]
                        for p in example_paths[:1]
                    ],
                    inputs=[extract_image, extract_template, extract_custom_schema],
                    label="Click an example to load it",
                    cache_examples=False,
                )

        # ============================================================
        # TAB 3: Custom Prompt
        # ============================================================
        with gr.Tab("💬 Custom Prompt", id="tab-custom"):
            gr.Markdown(
                "Send any custom prompt to the model along with an image. "
                "Great for experimenting with different instructions."
            )
            with gr.Row(equal_height=True):
                with gr.Column(scale=2, min_width=320):
                    custom_image = gr.Image(
                        type="pil",
                        label="Upload Image",
                        sources=["upload", "clipboard"],
                        height=300,
                    )
                    custom_prompt = gr.Textbox(
                        label="Your Prompt",
                        placeholder="e.g., Describe the contents of this image in detail...",
                        lines=4,
                    )
                    with gr.Accordion("💡 Prompt Ideas", open=False):
                        gr.Markdown(
                            """
**Document Parsing prompts:**
- `Text Recognition:` — extract all text
- `Formula Recognition:` — extract LaTeX formulas
- `Table Recognition:` — parse tables
**Information Extraction (use JSON schema):**
```
请按下列JSON格式输出图中信息:
{ "name": "", "date": "", "total": "" }
```
**Tips:**
- Be specific about what you want extracted
- For structured output, provide a JSON template
- The model works best with clear, direct instructions
"""
                        )
                    custom_btn = gr.Button(
                        "▶ Run",
                        variant="primary",
                        elem_id="run-btn-custom",
                        size="lg",
                    )
                with gr.Column(scale=3, min_width=400):
                    custom_stats = gr.Markdown("", elem_classes=["stats-label"])
                    custom_output = gr.Textbox(
                        label="Model Output",
                        lines=18,
                        max_lines=40,
                        elem_id="output-text-custom",
                        interactive=False,
                    )
            custom_btn.click(
                fn=handle_custom_prompt,
                inputs=[custom_image, custom_prompt],
                outputs=[custom_output, custom_stats],
            )

        # ============================================================
        # TAB 4: About
        # ============================================================
        with gr.Tab("ℹ️ About", id="tab-about"):
            gr.Markdown(
                """
## About GLM-OCR
**GLM-OCR** is a multimodal OCR model for complex document understanding,
built on the GLM-V encoder–decoder architecture. It combines:
- **CogViT** visual encoder pre-trained on large-scale image–text data
- A lightweight cross-modal connector with efficient token downsampling
- A **GLM-0.5B** language decoder
- Multi-Token Prediction (MTP) loss and stable full-task reinforcement learning
### Key Features
| Feature | Description |
|---------|-------------|
| **#1 on OmniDocBench V1.5** | Score of 94.62, state-of-the-art across document understanding benchmarks |
| **0.9B Parameters** | Efficient inference, ideal for production and edge deployment |
| **Multi-format** | Handles text, formulas, tables, code, seals, and complex layouts |
| **Multiple backends** | Supports vLLM, SGLang, Ollama, and Transformers |
### Supported Tasks
1. **Text Recognition** — Extract raw text from documents and images
2. **Formula Recognition** — Convert mathematical formulas to LaTeX
3. **Table Recognition** — Parse tables into structured HTML/Markdown
4. **Information Extraction** — Extract structured JSON from documents using custom schemas
### Links
- 🏠 [Model Card on Hugging Face](https://huggingface.co/zai-org/GLM-OCR)
- 📦 [Official SDK on GitHub](https://github.com/zai-org/GLM-OCR)
- 📄 License: **MIT**
---
*Built with [Gradio](https://gradio.app) and [Transformers](https://huggingface.co/docs/transformers).*
"""
            )

    # ---- Footer ----
    gr.HTML(
        '<div id="footer-info">'
        "GLM-OCR Studio &bull; Model: zai-org/GLM-OCR &bull; "
        "MIT License &bull; Powered by Gradio"
        "</div>"
    )
# ---------------------------------------------------------------------------
# Launch
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Fetch the sample images when the examples folder is absent or empty.
    examples_ready = os.path.isdir(EXAMPLES_DIR) and bool(os.listdir(EXAMPLES_DIR))
    if not examples_ready:
        print("📥 Downloading example images...")
        from download_examples import download_examples

        download_examples()

    # Listen on every interface at port 7860, without a public share link;
    # at most 10 requests are allowed to wait in the queue.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo.queue(max_size=10).launch(**launch_options)