""" ocr_engine.py — zai-org/GLM-OCR inference module GLM-OCR is a 0.9B multimodal OCR model built on the GLM-V encoder-decoder architecture. It uses a CogViT visual encoder + GLM-0.5B language decoder, trained with Multi-Token Prediction loss for high-quality document OCR. Model: https://huggingface.co/zai-org/GLM-OCR Paper: https://arxiv.org/abs/2603.10910 """ import io import time import logging import tempfile from dataclasses import dataclass from pathlib import Path from typing import Literal import torch import torch.nn.functional as F from PIL import Image from transformers import AutoProcessor, AutoModelForImageTextToText logger = logging.getLogger(__name__) # ── Config ───────────────────────────────────────────────────────────────── MODEL_ID = "zai-org/GLM-OCR" DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # Two prompt modes supported by GLM-OCR: # "recognize" → "Text Recognition:" (extract raw text, preserves structure) # "parse" → "Document Parsing:" (structured markdown output) OcrMode = Literal["recognize", "parse"] PROMPTS = { "recognize": "Text Recognition:", "parse": "Document Parsing:", } # ── Result dataclass ──────────────────────────────────────────────────────── @dataclass class OcrResult: text: str mode: str word_count: int char_count: int latency_ms: float device: str model_id: str # ── Engine ────────────────────────────────────────────────────────────────── class GlmOcrEngine: """ Wraps zai-org/GLM-OCR. Call .load() once at startup, then .run(image_bytes, mode) per request. """ def __init__(self): self.model = None self.processor = None self.loaded = False # ── Lifecycle ─────────────────────────────────────────────────────────── def load(self) -> None: if self.loaded: return logger.info(f"Loading {MODEL_ID} on {DEVICE} …") t0 = time.time() self.processor = AutoProcessor.from_pretrained( MODEL_ID, trust_remote_code=True, ) self.model = AutoModelForImageTextToText.from_pretrained( MODEL_ID, torch_dtype="auto", # fp16 on CUDA, fp32 on CPU device_map="auto", # spreads across available devices trust_remote_code=True, ) # ── CPU patch: replace the slow Conv3d patch_embed with matmul ────── # The default Conv3d produces ~22k individual 1x1x1 kernels on CPU # which is catastrophically slow. This replaces it with a single F.linear # call, bringing CPU inference from ~30min to ~30s per image. # See: https://huggingface.co/zai-org/GLM-OCR/discussions/36 if DEVICE == "cpu": self._apply_cpu_patch() self.model.eval() self.loaded = True logger.info(f"Model loaded in {time.time() - t0:.1f}s") def _apply_cpu_patch(self): """Replace Conv3d patch_embed with matmul for fast CPU inference.""" try: base_model = self.model.model if hasattr(self.model, 'model') else self.model patch_embed = base_model.visual.patch_embed proj = patch_embed.proj in_features = ( patch_embed.in_channels * patch_embed.temporal_patch_size * patch_embed.patch_size ** 2 ) embed_dim = patch_embed.embed_dim weight = proj.weight bias = proj.bias def _fast_forward(hidden_states: torch.Tensor) -> torch.Tensor: target_dtype = weight.dtype hidden_states = hidden_states.reshape(-1, in_features).to(dtype=target_dtype) return F.linear(hidden_states, weight.reshape(embed_dim, -1), bias) patch_embed.forward = _fast_forward logger.info("CPU matmul patch applied to patch_embed.") except Exception as e: logger.warning(f"Could not apply CPU patch (will still work, just slower): {e}") def unload(self) -> None: if self.model: del self.model del self.processor self.model = None self.processor = None self.loaded = False if torch.cuda.is_available(): torch.cuda.empty_cache() logger.info("Model unloaded.") # ── Inference ─────────────────────────────────────────────────────────── def run(self, image_bytes: bytes, mode: OcrMode = "recognize") -> OcrResult: """ Run GLM-OCR on raw image bytes. Args: image_bytes: Raw bytes of the uploaded image. mode: 'recognize' → plain text extraction ("Text Recognition:") 'parse' → structured markdown output ("Document Parsing:") Returns: OcrResult with extracted text and metadata. """ if not self.loaded: raise RuntimeError("Engine not loaded. Call .load() first.") # Validate image img = self._validate_image(image_bytes) # Save to temp file — processor loads from path/URL tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False) img.save(tmp.name, format="PNG") tmp.close() prompt_text = PROMPTS[mode] messages = [ { "role": "user", "content": [ {"type": "image", "url": tmp.name}, {"type": "text", "text": prompt_text}, ], } ] t0 = time.time() try: inputs = self.processor.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt", ).to(self.model.device) # token_type_ids not used by this model inputs.pop("token_type_ids", None) with torch.inference_mode(): generated_ids = self.model.generate( **inputs, max_new_tokens=8192, ) # Decode only the newly generated tokens output_text = self.processor.decode( generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False, ) finally: Path(tmp.name).unlink(missing_ok=True) latency_ms = (time.time() - t0) * 1000 text = output_text.strip() if output_text else "" return OcrResult( text = text, mode = mode, word_count = len(text.split()) if text else 0, char_count = len(text), latency_ms = round(latency_ms, 1), device = str(next(self.model.parameters()).device), model_id = MODEL_ID, ) # ── Helpers ───────────────────────────────────────────────────────────── @staticmethod def _validate_image(image_bytes: bytes) -> Image.Image: try: img = Image.open(io.BytesIO(image_bytes)) img.verify() img = Image.open(io.BytesIO(image_bytes)) return img.convert("RGB") except Exception as e: raise ValueError(f"Invalid image: {e}") from e @property def info(self) -> dict: return { "model_id": MODEL_ID, "device": DEVICE, "loaded": self.loaded, "cuda_available": torch.cuda.is_available(), "gpu_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None, "gpu_memory_gb": round( torch.cuda.get_device_properties(0).total_memory / 1e9, 1 ) if torch.cuda.is_available() else None, } # ── Singleton ─────────────────────────────────────────────────────────────── engine = GlmOcrEngine()