Spaces:
Sleeping
Sleeping
import os
import base64
from io import BytesIO
import warnings
import time
from typing import Union

import torch
from PIL import Image
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq

# -----------------------------------------------------------------------------
# Environment + warnings (quiet startup)
# -----------------------------------------------------------------------------
# Pin thread pools to 1 and silence tokenizer/transformers chatter so the
# Space boots fast and the container logs stay readable.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")

# -----------------------------------------------------------------------------
# Model config
# -----------------------------------------------------------------------------
# Hugging Face Hub id of the OCR vision-language model.
MODEL_ID = "allenai/olmOCR-2-7B-1025"

# Globals populated lazily on the first OCR request (see load_model()).
processor = None
model = None
def load_model():
    """Lazily initialize the global processor and model.

    Keeps Space startup instant: the heavyweight download/load only happens
    on the first OCR request, and repeat calls are no-ops.
    """
    global processor, model
    # Already initialized — nothing to do.
    if model is not None and processor is not None:
        return
    processor = AutoProcessor.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
    )
    loaded = AutoModelForVision2Seq.from_pretrained(
        MODEL_ID,
        dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    # Inference only — switch off dropout/batch-norm training behavior.
    model = loaded.eval()
    print("✅ olmOCR-2 model loaded")
| # ----------------------------------------------------------------------------- | |
| # Helpers | |
| # ----------------------------------------------------------------------------- | |
def _resize_max_side(img: Image.Image, max_side: int = 896) -> Image.Image:
    """Downscale *img* so its longer side is at most ``max_side`` pixels.

    Aspect ratio is preserved; images already within the limit are returned
    unchanged.  Each target dimension is clamped to >= 1: with the original
    ``int(w * scale)`` truncation, an extremely elongated image (e.g.
    10000x1) would produce a 0-pixel dimension and an invalid resize.
    """
    w, h = img.size
    longest = max(w, h)
    if longest <= max_side:
        return img
    scale = max_side / longest
    new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
    return img.resize(new_size, Image.LANCZOS)
def build_prompt(width: int, height: int) -> str:
    """Compose the OCR instruction prompt, embedding the page dimensions."""
    lines = [
        "Extract all readable text from this page image.",
        "Return ONLY the extracted text (no explanations, no markdown).",
        "Do not hallucinate.",
        "RAW_TEXT_START",
        f"Page dimensions: {width:.1f}x{height:.1f} "
        f"[Image 0x0 to {width:.1f}x{height:.1f}]",
        "RAW_TEXT_END",
    ]
    return "\n".join(lines)
def _coerce_to_pil(img: Union[Image.Image, dict, str]) -> Image.Image:
    """Turn any supported Gradio / gradio_client payload into a PIL Image.

    Accepts a ready PIL image, a filesystem path string, or a Gradio file
    dict carrying either a local ``path`` or a base64 ``data:image`` URL.
    Raises ``ValueError`` for anything else.
    """
    if isinstance(img, Image.Image):
        return img
    if isinstance(img, str):
        return Image.open(img)
    if isinstance(img, dict):
        local_path = img.get("path")
        if local_path:
            return Image.open(local_path)
        data_url = img.get("url")
        if data_url and data_url.startswith("data:image"):
            _header, payload = data_url.split(",", 1)
            raw = base64.b64decode(payload)
            return Image.open(BytesIO(raw))
    raise ValueError(f"Unsupported image input: {type(img)}")
| # ----------------------------------------------------------------------------- | |
| # OCR function (API) | |
| # ----------------------------------------------------------------------------- | |
def ocr_image(img: Union[Image.Image, dict, str]) -> tuple[str, str]:
    """Run OCR on one image and return ``(extracted_text, elapsed)``.

    Accepts a PIL image, a file path, or a Gradio file dict (normalized by
    ``_coerce_to_pil``).  On bad input, returns an error message instead of
    raising so the Gradio API always yields two strings.  ``elapsed`` is the
    wall-clock time formatted like ``"1.23s"``.
    """
    if img is None:
        return "No image uploaded.", "0.0s"
    start = time.perf_counter()
    load_model()
    try:
        img = _coerce_to_pil(img)
    except Exception as e:
        # Report malformed payloads to the caller rather than raising.
        return f"Invalid image input: {e}", "0.0s"
    # Normalize mode and cap the long side (default 896 px) before encoding.
    img = img.convert("RGB")
    img = _resize_max_side(img)
    w, h = img.size
    # Build prompt
    prompt = build_prompt(w, h)
    # Encode image for VLM message
    buf = BytesIO()
    img.save(buf, format="PNG")
    image_b64 = base64.b64encode(buf.getvalue()).decode()
    # NOTE(review): the image is supplied twice — as a base64 data URL inside
    # the chat message AND as `images=[img]` to the processor call below.
    # Presumably the chat template only emits an image placeholder and the
    # processor embeds the pixels once — confirm against the olmOCR-2
    # processor to rule out double-embedding.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                },
            ],
        }
    ]
    chat_text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = processor(
        text=[chat_text],
        images=[img],
        padding=True,
        return_tensors="pt",
    )
    # Move every tensor to the model's device; non-tensor entries pass through.
    inputs = {
        k: v.to(model.device) if torch.is_tensor(v) else v
        for k, v in inputs.items()
    }
    with torch.inference_mode():
        # Greedy decoding (do_sample=False) for deterministic OCR output.
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,
        )
    # generate() returns prompt + completion; slice off the prompt tokens.
    prompt_len = inputs["input_ids"].shape[1]
    gen_ids = output_ids[:, prompt_len:]
    text = processor.tokenizer.batch_decode(
        gen_ids, skip_special_tokens=True
    )
    elapsed = time.perf_counter() - start
    return (text[0].strip() if text else "No text extracted.", f"{elapsed:.2f}s")
| # ----------------------------------------------------------------------------- | |
| # Gradio UI + API | |
| # ----------------------------------------------------------------------------- | |
# -----------------------------------------------------------------------------
# Gradio UI + API
# -----------------------------------------------------------------------------
with gr.Blocks(title="BookReader OCR API (olmOCR2)") as demo:
    gr.Markdown(
        "# 📖 BookReader OCR API (olmOCR2)\n"
        "Upload an image and extract text using **olmOCR-2-7B**.\n\n"
        "**API endpoint:** `ocr`"
    )
    with gr.Row():
        # Left column: input controls.
        with gr.Column():
            upload = gr.Image(type="pil", label="Upload image")
            ocr_button = gr.Button("Run OCR", variant="primary")
        # Right column: results.
        with gr.Column():
            text_box = gr.Textbox(label="Extracted text", lines=15)
            time_box = gr.Textbox(label="Generation time", interactive=False)
    # Wire the button to the OCR function; api_name exposes it as `/ocr`.
    ocr_button.click(
        fn=ocr_image,
        inputs=upload,
        outputs=[text_box, time_box],
        api_name="ocr",
    )

if __name__ == "__main__":
    app = demo.queue()
    app.launch(show_error=True)