import os
import base64
import time
import warnings
from io import BytesIO
from typing import Union

# -----------------------------------------------------------------------------
# Environment + warnings (quiet startup)
# -----------------------------------------------------------------------------
# These must be set BEFORE torch/transformers are imported: transformers reads
# TRANSFORMERS_VERBOSITY when it configures its root logger during import, and
# the OpenMP runtime may read OMP_NUM_THREADS as soon as torch loads it. The
# warnings filter is installed first so import-time warnings are silenced too.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")

import torch  # noqa: E402  (must follow the env setup above)
from PIL import Image  # noqa: E402
import gradio as gr  # noqa: E402
from transformers import AutoProcessor, AutoModelForVision2Seq  # noqa: E402

# -----------------------------------------------------------------------------
# Model config
# -----------------------------------------------------------------------------
MODEL_ID = "allenai/olmOCR-2-7B-1025"

# Cap on newly generated tokens per image. Text that needs more tokens than
# this is cut off; raise it if long pages come back truncated.
MAX_NEW_TOKENS = 512

processor = None
model = None


def load_model():
    """Lazy-load the processor and model so the Space boots fast.

    Populates the module-level ``processor`` and ``model`` globals on the first
    call; once both are set, later calls return immediately.
    """
    global processor, model
    if processor is not None and model is not None:
        return

    processor = AutoProcessor.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
    )
    model = AutoModelForVision2Seq.from_pretrained(
        MODEL_ID,
        dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).eval()
    print("✅ olmOCR-2 model loaded")


# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
def _resize_max_side(img: Image.Image, max_side: int = 896) -> Image.Image:
    """Downscale *img* so its longer side is at most *max_side* pixels.

    The aspect ratio is preserved. An image already within the limit is returned
    unchanged (the same object). Each target side is clamped to >= 1 pixel, so a
    very thin image cannot produce a zero-size target, which ``resize`` rejects.
    """
    w, h = img.size
    longest = max(w, h)
    if longest <= max_side:
        return img
    scale = max_side / longest
    new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
    return img.resize(new_size, Image.LANCZOS)


def build_prompt(width: int, height: int) -> str:
    """Return the text instruction sent alongside the page image.

    *width* and *height* are the pixel size of the image actually sent to the
    model, embedded (with one decimal place) in the page-dimensions line.
    """
    return (
        "Extract all readable text from this page image.\n"
        "Return ONLY the extracted text (no explanations, no markdown).\n"
        "Do not hallucinate.\n"
        "RAW_TEXT_START\n"
        f"Page dimensions: {width:.1f}x{height:.1f} "
        f"[Image 0x0 to {width:.1f}x{height:.1f}]\n"
        "RAW_TEXT_END"
    )


def _coerce_to_pil(img: Union[Image.Image, dict, str]) -> Image.Image:
    """Normalize Gradio UI input and gradio_client input into a PIL Image.

    Accepts a PIL image (returned as-is), a file path string, or a dict with a
    ``"path"`` key or a ``"url"`` holding a ``data:image...;base64,`` payload.
    Note that ``Image.open`` is lazy, so undecodable data may only fail later,
    when the pixels are first accessed.

    Raises:
        ValueError: if *img* matches none of the supported shapes.
    """
    if isinstance(img, Image.Image):
        return img

    if isinstance(img, str):
        return Image.open(img)

    if isinstance(img, dict):
        path = img.get("path")
        if path:
            return Image.open(path)

        url = img.get("url")
        if url and url.startswith("data:image"):
            # Strip the "data:image/...;base64," header and decode the payload.
            _, b64 = url.split(",", 1)
            return Image.open(BytesIO(base64.b64decode(b64)))

    raise ValueError(f"Unsupported image input: {type(img)}")


# -----------------------------------------------------------------------------
# OCR function (API)
# -----------------------------------------------------------------------------
def ocr_image(img: Union[Image.Image, dict, str]) -> tuple[str, str]:
    """Run olmOCR-2 on a single image.

    Args:
        img: A PIL image (Gradio UI), a file path, or a gradio_client-style dict
            (see ``_coerce_to_pil``). ``None`` means nothing was uploaded.

    Returns:
        ``(text, elapsed)``: the extracted text (or a user-facing error/fallback
        message) and the wall-clock time of preprocessing + generation formatted
        like ``"1.23s"``. The one-off lazy model load is not included.
    """
    if img is None:
        return "No image uploaded.", "0.0s"

    # Normalize AND decode before touching the model. Image.open is lazy, so
    # corrupt data only fails at convert(); catching it here turns a bad upload
    # into a message, and a request that cannot succeed never triggers the slow
    # first-time model load.
    try:
        img = _coerce_to_pil(img).convert("RGB")
    except Exception as e:
        return f"Invalid image input: {e}", "0.0s"

    load_model()
    # Start timing after the lazy load so the first request's reported time is
    # not dominated by model initialisation.
    start = time.perf_counter()

    img = _resize_max_side(img)
    w, h = img.size
    prompt = build_prompt(w, h)

    # Encode the resized image as a PNG data URL for the chat message.
    # NOTE(review): the pixels the processor encodes come from `images=[img]`
    # below; the data URL presumably only makes the chat template emit an image
    # placeholder -- confirm before simplifying this.
    buf = BytesIO()
    img.save(buf, format="PNG")
    image_b64 = base64.b64encode(buf.getvalue()).decode()

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                },
            ],
        }
    ]

    chat_text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = processor(
        text=[chat_text],
        images=[img],
        padding=True,
        return_tensors="pt",
    )
    # Move tensors to the model's device; leave any non-tensor entries as-is.
    inputs = {
        k: v.to(model.device) if torch.is_tensor(v) else v
        for k, v in inputs.items()
    }

    with torch.inference_mode():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
        )

    # The output rows start with the prompt tokens; keep only the continuation.
    prompt_len = inputs["input_ids"].shape[1]
    gen_ids = output_ids[:, prompt_len:]
    decoded = processor.tokenizer.batch_decode(
        gen_ids, skip_special_tokens=True
    )
    elapsed = time.perf_counter() - start

    # batch_decode yields one string per row, so the list is never empty here;
    # test the stripped text so an empty generation gets the fallback message.
    text = decoded[0].strip() if decoded else ""
    return (text or "No text extracted.", f"{elapsed:.2f}s")
# -----------------------------------------------------------------------------
# Gradio UI + API
# -----------------------------------------------------------------------------
# Intro text rendered at the top of the page; the endpoint name it advertises
# is the api_name given to the click handler below.
_HEADER_MD = "\n".join(
    [
        "# 📖 BookReader OCR API (olmOCR2)",
        "Upload an image and extract text using **olmOCR-2-7B**.",
        "",
        "**API endpoint:** `ocr`",
    ]
)

with gr.Blocks(title="BookReader OCR API (olmOCR2)") as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Row():
        # Left column: the image to read and the button that starts OCR.
        with gr.Column():
            image_input = gr.Image(label="Upload image", type="pil")
            run_btn = gr.Button("Run OCR", variant="primary")

        # Right column: the two values returned by ocr_image.
        with gr.Column():
            output = gr.Textbox(label="Extracted text", lines=15)
            timing = gr.Textbox(label="Generation time", interactive=False)

    # One handler serves both the button and the named API endpoint "ocr".
    run_btn.click(
        ocr_image,
        inputs=image_input,
        outputs=[output, timing],
        api_name="ocr",
    )


if __name__ == "__main__":
    # Turn on Gradio's request queue, then serve; show_error surfaces handler
    # exceptions in the browser UI.
    demo.queue()
    demo.launch(show_error=True)