Spaces:
Sleeping
Sleeping
File size: 5,696 Bytes
1ab4133 6ba0575 097f30f 1ab4133 0597204 1ab4133 778e3eb 6ba0575 b2fc952 6ba0575 778e3eb 1ab4133 778e3eb 1ab4133 778e3eb 1ab4133 778e3eb 1ab4133 097f30f 1ab4133 778e3eb 1ab4133 778e3eb 1ab4133 778e3eb 1ab4133 0597204 097f30f 778e3eb 097f30f 778e3eb 097f30f 778e3eb 097f30f 778e3eb 097f30f c70f99d b2fc952 a82af97 778e3eb c70f99d a82af97 097f30f 778e3eb 097f30f c70f99d 778e3eb a82af97 778e3eb a82af97 778e3eb a82af97 778e3eb a82af97 c70f99d a82af97 240f8ec a82af97 c70f99d a82af97 097f30f 778e3eb 66fe0e6 c70f99d a82af97 b2fc952 240f8ec a82af97 c70f99d 778e3eb b2fc952 778e3eb 1ab4133 778e3eb 1ab4133 778e3eb 23bf303 1ab4133 92e866a 6ba0575 92e866a b2fc952 92e866a 1ab4133 778e3eb b2fc952 23bf303 1ab4133 778e3eb 6ba0575 ffb2e43 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
import os
import base64
from io import BytesIO
import warnings
import time
from typing import Union
import torch
from PIL import Image
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq
# -----------------------------------------------------------------------------
# Environment + warnings (quiet startup)
# -----------------------------------------------------------------------------
# Keep OpenMP to a single thread (small Space CPUs; avoids oversubscription).
os.environ["OMP_NUM_THREADS"] = "1"
# Only print transformers errors — suppress download/progress chatter.
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
# Silence the tokenizers fork-parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")
# -----------------------------------------------------------------------------
# Model config
# -----------------------------------------------------------------------------
# Hugging Face Hub id of the OCR vision-language model.
MODEL_ID = "allenai/olmOCR-2-7B-1025"
# Lazily populated by load_model() on first request (see below).
processor = None
model = None
def load_model():
    """Lazy-load the processor and model so the Space boots fast.

    Populates the module-level ``processor`` and ``model`` globals on the
    first call; subsequent calls are no-ops.
    """
    global processor, model
    # Already initialized — nothing to do.
    if not (processor is None or model is None):
        return
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    # fp16 weights, auto device placement (GPU when available).
    model = AutoModelForVision2Seq.from_pretrained(
        MODEL_ID,
        dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    model.eval()  # inference mode — disables dropout etc.
    print("✅ olmOCR-2 model loaded")
# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
def _resize_max_side(img: Image.Image, max_side: int = 896) -> Image.Image:
    """Downscale *img* so its longest side is at most ``max_side`` pixels.

    Images already within the limit are returned unchanged; otherwise the
    aspect ratio is preserved and LANCZOS resampling is used.
    """
    width, height = img.size
    longest = width if width >= height else height
    if longest > max_side:
        ratio = max_side / longest
        new_size = (int(width * ratio), int(height * ratio))
        return img.resize(new_size, Image.LANCZOS)
    return img
def build_prompt(width: int, height: int) -> str:
    """Return the OCR instruction prompt for a page of the given size.

    The RAW_TEXT_START/RAW_TEXT_END markers and the page-dimension line
    follow the olmOCR prompt convention; dimensions are rendered with one
    decimal place.
    """
    dims = f"{width:.1f}x{height:.1f}"
    lines = [
        "Extract all readable text from this page image.",
        "Return ONLY the extracted text (no explanations, no markdown).",
        "Do not hallucinate.",
        "RAW_TEXT_START",
        f"Page dimensions: {dims} [Image 0x0 to {dims}]",
        "RAW_TEXT_END",
    ]
    return "\n".join(lines)
def _coerce_to_pil(img: Union[Image.Image, dict, str]) -> Image.Image:
    """Convert the various Gradio input shapes into a PIL image.

    Accepts an already-open PIL image, a filesystem path string, or the
    dict payload produced by gradio_client (a ``path`` key, or a base64
    ``url`` of the form ``data:image/...``).

    Raises:
        ValueError: when the input matches none of the supported shapes.
    """
    if isinstance(img, Image.Image):
        return img
    if isinstance(img, str):
        # Plain filesystem path.
        return Image.open(img)
    if isinstance(img, dict):
        file_path = img.get("path")
        if file_path:
            return Image.open(file_path)
        data_url = img.get("url")
        if data_url and data_url.startswith("data:image"):
            # Decode the base64 payload after the "data:image/...;base64," header.
            encoded = data_url.split(",", 1)[1]
            return Image.open(BytesIO(base64.b64decode(encoded)))
    raise ValueError(f"Unsupported image input: {type(img)}")
# -----------------------------------------------------------------------------
# OCR function (API)
# -----------------------------------------------------------------------------
def ocr_image(img: Union[Image.Image, dict, str]) -> tuple[str, str]:
    """Run OCR on one page image.

    Args:
        img: PIL image, file path, or gradio_client dict payload.

    Returns:
        A ``(extracted_text, timing_string)`` tuple; the timing string is
        formatted as e.g. ``"3.21s"``. Errors are reported in the text slot
        rather than raised, so the Gradio UI always gets a pair back.
    """
    if img is None:
        return "No image uploaded.", "0.0s"

    start = time.perf_counter()
    load_model()  # no-op after the first call

    try:
        pil_img = _coerce_to_pil(img)
    except Exception as e:
        return f"Invalid image input: {e}", "0.0s"

    # Normalize to RGB and cap the longest side for faster inference.
    pil_img = _resize_max_side(pil_img.convert("RGB"))
    width, height = pil_img.size

    prompt = build_prompt(width, height)

    # Base64-encode the page as PNG for the chat-message image payload.
    png_buf = BytesIO()
    pil_img.save(png_buf, format="PNG")
    encoded = base64.b64encode(png_buf.getvalue()).decode()

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{encoded}"},
                },
            ],
        }
    ]

    chat_text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = processor(
        text=[chat_text],
        images=[pil_img],
        padding=True,
        return_tensors="pt",
    )
    # Move tensors to the model's device; leave any non-tensor entries alone.
    model_inputs = {
        key: val.to(model.device) if torch.is_tensor(val) else val
        for key, val in model_inputs.items()
    }

    with torch.inference_mode():
        generated = model.generate(
            **model_inputs,
            max_new_tokens=512,
            do_sample=False,  # greedy decoding for deterministic OCR output
        )

    # Drop the prompt tokens; decode only the newly generated tail.
    new_tokens = generated[:, model_inputs["input_ids"].shape[1]:]
    decoded = processor.tokenizer.batch_decode(
        new_tokens, skip_special_tokens=True
    )

    elapsed = time.perf_counter() - start
    result = decoded[0].strip() if decoded else "No text extracted."
    return result, f"{elapsed:.2f}s"
# -----------------------------------------------------------------------------
# Gradio UI + API
# -----------------------------------------------------------------------------
with gr.Blocks(title="BookReader OCR API (olmOCR2)") as demo:
    gr.Markdown(
        "# 📖 BookReader OCR API (olmOCR2)\n"
        "Upload an image and extract text using **olmOCR-2-7B**.\n\n"
        "**API endpoint:** `ocr`"
    )
    with gr.Row():
        with gr.Column():
            # Left column: input image and the button that triggers OCR.
            image_input = gr.Image(type="pil", label="Upload image")
            run_btn = gr.Button("Run OCR", variant="primary")
        with gr.Column():
            # Right column: extracted text plus a read-only timing readout.
            output = gr.Textbox(label="Extracted text", lines=15)
            timing = gr.Textbox(label="Generation time", interactive=False)
    # Wire the button to ocr_image; api_name="ocr" also exposes this handler
    # as a named endpoint for gradio_client callers.
    run_btn.click(
        fn=ocr_image,
        inputs=image_input,
        outputs=[output, timing],
        api_name="ocr",
    )
if __name__ == "__main__":
    # queue() enables request queuing (model inference is serialized work);
    # show_error surfaces server-side exceptions in the UI.
    demo.queue().launch(show_error=True)
|