# olmocr-api / app.py
# Author: chitrark — "update ocr path" (commit 23bf303, verified)
import os
import base64
from io import BytesIO
import warnings
import time
from typing import Union
import torch
from PIL import Image
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq
# -----------------------------------------------------------------------------
# Environment + warnings (quiet startup)
# -----------------------------------------------------------------------------
# Keep OpenMP/BLAS single-threaded and silence transformers/tokenizers chatter
# so the Space log stays readable during boot.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Suppress all Python warnings globally (deliberate: demo app, noisy deps).
warnings.filterwarnings("ignore")
# -----------------------------------------------------------------------------
# Model config
# -----------------------------------------------------------------------------
# Hugging Face Hub repo id of the OCR vision-language model.
MODEL_ID = "allenai/olmOCR-2-7B-1025"
# Lazily-initialized globals; populated by load_model() on the first OCR call
# so the Space starts serving before the 7B model finishes downloading.
processor = None
model = None
def load_model():
    """Populate the module-level ``processor``/``model`` pair on first use.

    Subsequent calls are no-ops, so the Space boots quickly and the heavy
    model download happens only when the first OCR request arrives.
    """
    global processor, model
    already_loaded = processor is not None and model is not None
    if already_loaded:
        return
    processor = AutoProcessor.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
    )
    loaded = AutoModelForVision2Seq.from_pretrained(
        MODEL_ID,
        dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    # Inference-only: switch off dropout/batch-norm training behavior.
    model = loaded.eval()
    print("✅ olmOCR-2 model loaded")
# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
def _resize_max_side(img: Image.Image, max_side: int = 896) -> Image.Image:
w, h = img.size
m = max(w, h)
if m <= max_side:
return img
scale = max_side / m
return img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
def build_prompt(width: int, height: int) -> str:
    """Compose the OCR instruction prompt, embedding the page dimensions.

    The RAW_TEXT_START/RAW_TEXT_END sentinels and the dimension line follow
    the olmOCR prompt format expected by the model.
    """
    pieces = [
        "Extract all readable text from this page image.",
        "Return ONLY the extracted text (no explanations, no markdown).",
        "Do not hallucinate.",
        "RAW_TEXT_START",
        (
            f"Page dimensions: {width:.1f}x{height:.1f} "
            f"[Image 0x0 to {width:.1f}x{height:.1f}]"
        ),
        "RAW_TEXT_END",
    ]
    return "\n".join(pieces)
def _coerce_to_pil(img: Union[Image.Image, dict, str]) -> Image.Image:
    """Turn a Gradio UI / gradio_client image payload into a PIL Image.

    Accepts a ready PIL image, a filesystem path, or a Gradio file dict
    carrying either a local ``path`` or a base64 ``data:image`` URL.

    Raises:
        ValueError: for any payload that cannot be decoded into an image.
    """
    if isinstance(img, Image.Image):
        return img
    if isinstance(img, str):
        return Image.open(img)
    if isinstance(img, dict):
        local_path = img.get("path")
        if local_path:
            return Image.open(local_path)
        data_url = img.get("url")
        if data_url and data_url.startswith("data:image"):
            _, encoded = data_url.split(",", 1)
            raw = base64.b64decode(encoded)
            return Image.open(BytesIO(raw))
    raise ValueError(f"Unsupported image input: {type(img)}")
# -----------------------------------------------------------------------------
# OCR function (API)
# -----------------------------------------------------------------------------
def ocr_image(img: Union[Image.Image, dict, str]) -> tuple[str, str]:
    """Run OCR on a single page image via olmOCR-2.

    Args:
        img: A PIL image, a filesystem path, or a Gradio file dict
            (``path`` key or base64 ``data:image`` URL).

    Returns:
        ``(extracted_text, elapsed)`` where ``elapsed`` is a string such as
        ``"1.23s"``. Errors are reported in the text slot instead of being
        raised, so the Gradio UI always receives two strings.
    """
    if img is None:
        return "No image uploaded.", "0.0s"
    start = time.perf_counter()
    # First call triggers the (slow) model download/initialization.
    load_model()
    try:
        img = _coerce_to_pil(img)
    except Exception as e:
        return f"Invalid image input: {e}", "0.0s"
    # Normalize to RGB and cap the longer side at 896 px to bound VRAM use.
    img = img.convert("RGB")
    img = _resize_max_side(img)
    w, h = img.size
    # Build prompt
    prompt = build_prompt(w, h)
    # Encode image for VLM message
    # NOTE(review): the base64 data URL below duplicates the image that is
    # also passed via ``images=[img]`` to the processor — presumably the chat
    # template consumes the message form while the processor consumes the
    # pixel values; confirm against the model's processor implementation.
    buf = BytesIO()
    img.save(buf, format="PNG")
    image_b64 = base64.b64encode(buf.getvalue()).decode()
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                },
            ],
        }
    ]
    # Render the chat-formatted prompt string (generation prompt appended).
    chat_text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = processor(
        text=[chat_text],
        images=[img],
        padding=True,
        return_tensors="pt",
    )
    # Move every tensor onto the model's device; leave non-tensors untouched.
    inputs = {
        k: v.to(model.device) if torch.is_tensor(v) else v
        for k, v in inputs.items()
    }
    with torch.inference_mode():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,  # greedy decoding for deterministic OCR output
        )
    # generate() returns prompt + continuation; strip the prompt tokens.
    prompt_len = inputs["input_ids"].shape[1]
    gen_ids = output_ids[:, prompt_len:]
    text = processor.tokenizer.batch_decode(
        gen_ids, skip_special_tokens=True
    )
    elapsed = time.perf_counter() - start
    return (text[0].strip() if text else "No text extracted.", f"{elapsed:.2f}s")
# -----------------------------------------------------------------------------
# Gradio UI + API
# -----------------------------------------------------------------------------
# Gradio UI; api_name="ocr" also exposes the handler to gradio_client callers.
with gr.Blocks(title="BookReader OCR API (olmOCR2)") as demo:
    gr.Markdown(
        "# 📖 BookReader OCR API (olmOCR2)\n"
        "Upload an image and extract text using **olmOCR-2-7B**.\n\n"
        "**API endpoint:** `ocr`"
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload image")
            run_btn = gr.Button("Run OCR", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Extracted text", lines=15)
            timing = gr.Textbox(label="Generation time", interactive=False)
    run_btn.click(
        fn=ocr_image,
        inputs=image_input,
        outputs=[output, timing],
        api_name="ocr",
    )
if __name__ == "__main__":
    # queue() enables request queuing for concurrent callers; show_error
    # surfaces handler exceptions in the UI instead of a generic message.
    demo.queue().launch(show_error=True)