Spaces:
Sleeping
Sleeping
| """ | |
| ocr_engine.py — zai-org/GLM-OCR inference module | |
| GLM-OCR is a 0.9B multimodal OCR model built on the GLM-V encoder-decoder | |
| architecture. It uses a CogViT visual encoder + GLM-0.5B language decoder, | |
| trained with Multi-Token Prediction loss for high-quality document OCR. | |
| Model: https://huggingface.co/zai-org/GLM-OCR | |
| Paper: https://arxiv.org/abs/2603.10910 | |
| """ | |
| import io | |
| import time | |
| import logging | |
| import tempfile | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Literal | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from transformers import AutoProcessor, AutoModelForImageTextToText | |
# Module-level logger; handlers and levels are configured by the host app.
logger = logging.getLogger(__name__)

# ── Config ──────────────────────────────────────────────────────────────────
# Hugging Face Hub repository the processor and model weights are loaded from.
MODEL_ID = "zai-org/GLM-OCR"

# Chosen once at import time: "cuda" when torch reports a usable GPU, else "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Two prompt modes supported by GLM-OCR:
#   "recognize" → "Text Recognition:"  (extract raw text, preserves structure)
#   "parse"     → "Document Parsing:"  (structured markdown output)
OcrMode = Literal["recognize", "parse"]

# Maps each OcrMode value to the literal prompt string sent to the model.
PROMPTS = {
    "recognize": "Text Recognition:",
    "parse": "Document Parsing:",
}
# ── Result dataclass ────────────────────────────────────────────────────────
@dataclass
class OcrResult:
    """Outcome of one OCR request together with bookkeeping metadata.

    Instances are built by ``GlmOcrEngine.run`` using keyword arguments, so
    the class must be a dataclass (a bare class with only annotations has no
    generated ``__init__`` and the construction would raise ``TypeError``).
    """

    text: str          # decoded model output with surrounding whitespace stripped
    mode: str          # prompt mode the request was run with ("recognize" / "parse")
    word_count: int    # number of whitespace-separated tokens in ``text``
    char_count: int    # ``len(text)``
    latency_ms: float  # inference time in milliseconds, rounded to one decimal
    device: str        # device of the model parameters, as a string
    model_id: str      # Hub identifier of the model that produced the result
# ── Engine ──────────────────────────────────────────────────────────────────
class GlmOcrEngine:
    """
    Wraps zai-org/GLM-OCR. Call .load() once at startup,
    then .run(image_bytes, mode) per request.
    """

    def __init__(self):
        """Create an empty engine; no weights are loaded until ``load()``."""
        self.model = None
        self.processor = None
        self.loaded = False

    # ── Lifecycle ───────────────────────────────────────────────────────────
    def load(self) -> None:
        """Load the processor and model for ``MODEL_ID``.

        Does nothing when the engine is already loaded. On a CPU-only host
        the visual patch embedding is additionally patched for speed (see
        ``_apply_cpu_patch``).
        """
        if self.loaded:
            return

        logger.info("Loading %s on %s …", MODEL_ID, DEVICE)
        started = time.perf_counter()

        self.processor = AutoProcessor.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
        )
        self.model = AutoModelForImageTextToText.from_pretrained(
            MODEL_ID,
            torch_dtype="auto",   # fp16 on CUDA, fp32 on CPU
            device_map="auto",    # spreads across available devices
            trust_remote_code=True,
        )

        # ── CPU patch: replace the slow Conv3d patch_embed with matmul ──────
        # The default Conv3d produces ~22k individual 1x1x1 kernels on CPU
        # which is catastrophically slow. This replaces it with a single
        # F.linear call, bringing CPU inference from ~30min to ~30s per image.
        # See: https://huggingface.co/zai-org/GLM-OCR/discussions/36
        if DEVICE == "cpu":
            self._apply_cpu_patch()

        self.model.eval()
        self.loaded = True
        logger.info("Model loaded in %.1fs", time.perf_counter() - started)

    def _apply_cpu_patch(self):
        """Replace Conv3d patch_embed with matmul for fast CPU inference.

        Best-effort: any failure (for example an unexpected module layout)
        is logged as a warning and the unpatched model is left in place.
        """
        try:
            # Some wrappers nest the backbone under ``.model``; fall back to
            # the object itself when that attribute is absent.
            base_model = getattr(self.model, "model", self.model)
            patch_embed = base_model.visual.patch_embed
            proj = patch_embed.proj

            in_features = (
                patch_embed.in_channels
                * patch_embed.temporal_patch_size
                * patch_embed.patch_size ** 2
            )
            embed_dim = patch_embed.embed_dim
            weight = proj.weight
            bias = proj.bias

            def _fast_forward(hidden_states: torch.Tensor) -> torch.Tensor:
                # Flatten every patch to one row and apply the convolution
                # kernel as a plain linear layer. The weight is reshaped per
                # call so later dtype/device changes to the parameter are
                # picked up.
                flat = hidden_states.reshape(-1, in_features).to(dtype=weight.dtype)
                return F.linear(flat, weight.reshape(embed_dim, -1), bias)

            patch_embed.forward = _fast_forward
            logger.info("CPU matmul patch applied to patch_embed.")
        except Exception as e:
            logger.warning(
                f"Could not apply CPU patch (will still work, just slower): {e}"
            )

    def unload(self) -> None:
        """Drop the model and processor and release cached CUDA memory.

        Does nothing when neither object is currently held.
        """
        # Compare against None explicitly rather than relying on the
        # truthiness of a model object.
        if self.model is None and self.processor is None:
            return

        self.model = None
        self.processor = None
        self.loaded = False

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("Model unloaded.")

    # ── Inference ───────────────────────────────────────────────────────────
    def run(
        self,
        image_bytes: bytes,
        mode: OcrMode = "recognize",
        *,
        max_new_tokens: int = 8192,
    ) -> OcrResult:
        """
        Run GLM-OCR on raw image bytes.

        Args:
            image_bytes: Raw bytes of the uploaded image.
            mode:
                'recognize' → plain text extraction ("Text Recognition:")
                'parse'     → structured markdown output ("Document Parsing:")
            max_new_tokens: Upper bound on generated tokens (keyword-only).

        Returns:
            OcrResult with extracted text and metadata.

        Raises:
            RuntimeError: If ``load()`` has not been called.
            ValueError: If ``image_bytes`` is not a decodable image.
            KeyError: If ``mode`` is not a key of ``PROMPTS``.
        """
        if not self.loaded:
            raise RuntimeError("Engine not loaded. Call .load() first.")

        # Validate image
        img = self._validate_image(image_bytes)

        # Resolve the prompt before anything touches the filesystem so that a
        # bad mode cannot leave a temp file behind.
        prompt_text = PROMPTS[mode]

        # Save to temp file — processor loads from path/URL.
        # Only the name is reserved here; the handle is closed immediately so
        # the file can be reopened by name on every platform.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp_path = Path(tmp.name)

        try:
            img.save(tmp_path, format="PNG")

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "url": str(tmp_path)},
                        {"type": "text", "text": prompt_text},
                    ],
                }
            ]

            started = time.perf_counter()
            inputs = self.processor.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.model.device)

            # token_type_ids not used by this model
            inputs.pop("token_type_ids", None)

            with torch.inference_mode():
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                )

            # Decode only the newly generated tokens
            prompt_len = inputs["input_ids"].shape[1]
            output_text = self.processor.decode(
                generated_ids[0][prompt_len:],
                skip_special_tokens=False,
            )
        finally:
            # The temp file is removed whether or not inference succeeded.
            tmp_path.unlink(missing_ok=True)

        latency_ms = (time.perf_counter() - started) * 1000
        text = output_text.strip() if output_text else ""

        return OcrResult(
            text=text,
            mode=mode,
            word_count=len(text.split()) if text else 0,
            char_count=len(text),
            latency_ms=round(latency_ms, 1),
            device=str(next(self.model.parameters()).device),
            model_id=MODEL_ID,
        )

    # ── Helpers ─────────────────────────────────────────────────────────────
    @staticmethod
    def _validate_image(image_bytes: bytes) -> Image.Image:
        """Decode ``image_bytes`` and return it as an RGB image.

        Raises:
            ValueError: If the bytes cannot be verified or decoded as an image.
        """
        try:
            # verify() leaves the image object unusable afterwards, so the
            # bytes are opened a second time for the actual decode.
            with Image.open(io.BytesIO(image_bytes)) as probe:
                probe.verify()
            img = Image.open(io.BytesIO(image_bytes))
            return img.convert("RGB")
        except Exception as e:
            raise ValueError(f"Invalid image: {e}") from e

    def info(self) -> dict:
        """Return a summary of the model, device and load state.

        GPU fields are ``None`` when CUDA is not available.
        """
        has_cuda = torch.cuda.is_available()
        gpu_name = None
        gpu_memory_gb = None
        if has_cuda:
            gpu_name = torch.cuda.get_device_name(0)
            gpu_memory_gb = round(
                torch.cuda.get_device_properties(0).total_memory / 1e9, 1
            )

        return {
            "model_id": MODEL_ID,
            "device": DEVICE,
            "loaded": self.loaded,
            "cuda_available": has_cuda,
            "gpu_name": gpu_name,
            "gpu_memory_gb": gpu_memory_gb,
        }
# ── Singleton ───────────────────────────────────────────────────────────────
# Shared module-level engine instance. Constructing it loads nothing; call
# engine.load() before the first engine.run().
engine = GlmOcrEngine()