GLMOCR_Text_extraction / ocr_engine.py
Sam20202's picture
Initial deploy
0533780
"""
ocr_engine.py — zai-org/GLM-OCR inference module
GLM-OCR is a 0.9B multimodal OCR model built on the GLM-V encoder-decoder
architecture. It uses a CogViT visual encoder + GLM-0.5B language decoder,
trained with Multi-Token Prediction loss for high-quality document OCR.
Model: https://huggingface.co/zai-org/GLM-OCR
Paper: https://arxiv.org/abs/2603.10910
"""
import io
import time
import logging
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText
logger = logging.getLogger(name=__name__)

# ── Config ─────────────────────────────────────────────────────────────────
MODEL_ID = "zai-org/GLM-OCR"

# Run on the GPU whenever PyTorch can see one, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# GLM-OCR is steered by a task prompt; two prompt modes are exposed:
#   "recognize" → "Text Recognition:"  (extract raw text, preserves structure)
#   "parse"     → "Document Parsing:"  (structured markdown output)
OcrMode = Literal["recognize", "parse"]
PROMPTS = dict(
    recognize="Text Recognition:",
    parse="Document Parsing:",
)
# ── Result dataclass ────────────────────────────────────────────────────────
@dataclass()
class OcrResult:
    """Outcome of one OCR call: the extracted text plus run metadata."""

    text: str          # decoded model output with surrounding whitespace stripped
    mode: str          # prompt mode key that was used ("recognize" or "parse")
    word_count: int    # number of whitespace-separated tokens in `text`
    char_count: int    # len(text)
    latency_ms: float  # elapsed inference time in milliseconds (rounded to 0.1)
    device: str        # device of the model's first parameter, e.g. "cpu"
    model_id: str      # Hugging Face repo id of the model that produced `text`
# ── Engine ──────────────────────────────────────────────────────────────────
class GlmOcrEngine:
    """
    Wraps zai-org/GLM-OCR. Call .load() once at startup,
    then .run(image_bytes, mode) per request.

    Attributes:
        model: the loaded AutoModelForImageTextToText, or None when unloaded.
        processor: the matching AutoProcessor, or None when unloaded.
        loaded: True once load() has completed successfully.
    """

    def __init__(self):
        self.model = None
        self.processor = None
        self.loaded = False

    # ── Lifecycle ───────────────────────────────────────────────────────────
    def load(self) -> None:
        """Load the processor and model; calling again once loaded is a no-op.

        On CPU this also tries to install a faster patch-embedding forward
        (see _apply_cpu_patch); failure to do so only logs a warning.
        """
        if self.loaded:
            return
        logger.info("Loading %s on %s …", MODEL_ID, DEVICE)
        t0 = time.perf_counter()
        # NOTE: trust_remote_code=True allows the checkpoint repo to execute its
        # own Python code while loading; only point MODEL_ID at trusted repos.
        self.processor = AutoProcessor.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
        )
        self.model = AutoModelForImageTextToText.from_pretrained(
            MODEL_ID,
            # "auto" takes the dtype recorded in the checkpoint (config/weights);
            # it does not pick a dtype per device.
            torch_dtype="auto",
            device_map="auto",  # let transformers/accelerate place the weights
            trust_remote_code=True,
        )
        # The vision encoder's Conv3d patch embedding is reported to be extremely
        # slow on CPU (~30 min vs ~30 s per image once replaced by a matmul).
        # See: https://huggingface.co/zai-org/GLM-OCR/discussions/36
        if DEVICE == "cpu":
            self._apply_cpu_patch()
        self.model.eval()
        self.loaded = True
        logger.info("Model loaded in %.1fs", time.perf_counter() - t0)

    def _apply_cpu_patch(self) -> None:
        """Replace the Conv3d patch_embed forward with an equivalent F.linear.

        Best-effort: if the expected attributes are missing or the projection's
        layout does not match the assumptions below, the model is left
        unpatched and a warning is logged.
        """
        try:
            base_model = getattr(self.model, "model", self.model)
            patch_embed = base_model.visual.patch_embed
            proj = patch_embed.proj
            kernel = (
                patch_embed.temporal_patch_size,
                patch_embed.patch_size,
                patch_embed.patch_size,
            )
            in_features = patch_embed.in_channels * kernel[0] * kernel[1] * kernel[2]
            embed_dim = patch_embed.embed_dim
            weight = proj.weight
            bias = proj.bias

            # The linear rewrite is only equivalent when every input row is one
            # non-overlapping patch (kernel == stride) and the weight flattens to
            # (embed_dim, in_features). Check now, so a mismatch becomes a
            # startup warning rather than a reshape error on every request.
            if (
                tuple(proj.kernel_size) != kernel
                or tuple(proj.stride) != kernel
                or weight.shape[0] != embed_dim
                or weight.shape[1:].numel() != in_features
            ):
                raise ValueError(
                    f"unexpected patch_embed.proj layout: weight {tuple(weight.shape)}, "
                    f"kernel {tuple(proj.kernel_size)}, stride {tuple(proj.stride)}"
                )

            def _fast_forward(hidden_states: torch.Tensor) -> torch.Tensor:
                # Flatten each patch to one row, match the weight dtype, and
                # project with a single matmul instead of the Conv3d.
                hidden_states = hidden_states.reshape(-1, in_features).to(dtype=weight.dtype)
                return F.linear(hidden_states, weight.reshape(embed_dim, -1), bias)

            patch_embed.forward = _fast_forward
            logger.info("CPU matmul patch applied to patch_embed.")
        except Exception as e:
            # Deliberately non-fatal: the unpatched model still works, just slower.
            logger.warning("Could not apply CPU patch (will still work, just slower): %s", e)

    def unload(self) -> None:
        """Drop the model and processor references and mark the engine unloaded.

        Also asks PyTorch to release its unused cached CUDA memory when CUDA is
        available.
        """
        # Reset both handles unconditionally: a load() that failed part-way can
        # leave a processor set while the model is still None.
        self.model = None
        self.processor = None
        self.loaded = False
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("Model unloaded.")

    # ── Inference ───────────────────────────────────────────────────────────
    def run(
        self,
        image_bytes: bytes,
        mode: OcrMode = "recognize",
        *,
        max_new_tokens: int = 8192,
    ) -> OcrResult:
        """
        Run GLM-OCR on raw image bytes.

        Args:
            image_bytes: Raw bytes of the uploaded image.
            mode:
                'recognize' → plain text extraction ("Text Recognition:")
                'parse'     → structured markdown output ("Document Parsing:")
            max_new_tokens: Upper bound on generated tokens (keyword-only;
                defaults to 8192).

        Returns:
            OcrResult with extracted text and metadata.

        Raises:
            RuntimeError: if .load() has not been called.
            ValueError: if image_bytes is not a decodable image, or
                max_new_tokens is not positive.
            KeyError: if mode is not one of the PROMPTS keys.
        """
        if not self.loaded:
            raise RuntimeError("Engine not loaded. Call .load() first.")
        if max_new_tokens < 1:
            raise ValueError("max_new_tokens must be a positive integer")

        # Validate everything before creating the temp file, so bad input can
        # never leave a stray file behind.
        img = self._validate_image(image_bytes)
        prompt_text = PROMPTS[mode]

        # The processor's chat template loads the image from a path/URL, so the
        # decoded image is round-tripped through a temporary PNG. Our handle is
        # closed immediately; PIL then writes via the path.
        tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        tmp.close()
        tmp_path = Path(tmp.name)
        try:
            img.save(tmp_path, format="PNG")
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "url": tmp.name},
                        {"type": "text", "text": prompt_text},
                    ],
                }
            ]

            t0 = time.perf_counter()
            inputs = self.processor.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.model.device)
            # token_type_ids not used by this model
            inputs.pop("token_type_ids", None)

            with torch.inference_mode():
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                )

            # Decode only the newly generated tokens (skip the echoed prompt).
            # NOTE(review): special tokens are kept in the output
            # (skip_special_tokens=False) — confirm they should reach callers.
            prompt_len = inputs["input_ids"].shape[1]
            output_text = self.processor.decode(
                generated_ids[0][prompt_len:],
                skip_special_tokens=False,
            )
            latency_ms = (time.perf_counter() - t0) * 1000
        finally:
            tmp_path.unlink(missing_ok=True)

        text = output_text.strip() if output_text else ""
        return OcrResult(
            text=text,
            mode=mode,
            word_count=len(text.split()),  # "".split() == [] → 0 for empty text
            char_count=len(text),
            latency_ms=round(latency_ms, 1),
            device=str(next(self.model.parameters()).device),
            model_id=MODEL_ID,
        )

    # ── Helpers ─────────────────────────────────────────────────────────────
    @staticmethod
    def _validate_image(image_bytes: bytes) -> Image.Image:
        """Decode image bytes and return an RGB copy.

        Raises:
            ValueError: wrapping whatever PIL raised for unreadable or corrupt data.
        """
        try:
            # verify() checks integrity but leaves the image unusable, so the
            # bytes are opened a second time to actually decode the pixels.
            with Image.open(io.BytesIO(image_bytes)) as probe:
                probe.verify()
            with Image.open(io.BytesIO(image_bytes)) as img:
                # convert() returns a new image, so closing `img` afterwards is safe.
                return img.convert("RGB")
        except Exception as e:
            raise ValueError(f"Invalid image: {e}") from e

    @property
    def info(self) -> dict:
        """Static configuration and, when CUDA is available, GPU 0 details."""
        cuda = torch.cuda.is_available()
        return {
            "model_id": MODEL_ID,
            "device": DEVICE,
            "loaded": self.loaded,
            "cuda_available": cuda,
            "gpu_name": torch.cuda.get_device_name(0) if cuda else None,
            "gpu_memory_gb": (
                round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1)
                if cuda
                else None
            ),
        }
# ── Singleton ───────────────────────────────────────────────────────────────
# Shared module-level instance. It is created unloaded: call engine.load()
# before engine.run(), which otherwise raises RuntimeError.
engine: GlmOcrEngine = GlmOcrEngine()