# arabic-audio-reader-worker / scripts / baseer_ocr_extract.py
# Last commit by Syncre: "Deploy Arabic Audio Reader worker" (2e1a095, verified)
from __future__ import annotations
import argparse
import json
import re
import sys
from pathlib import Path
# Matches HTML/XML-style tags such as "<p>", "</div>" or "<br/>" and chat-style
# markers such as "<|im_end|>". A bare "<" or ">" used as a comparison sign in the
# page text is deliberately NOT matched, so that text survives cleaning.
TAG_RE = re.compile(r"</?[A-Za-z][^<>]*>|<\|[^<>]*\|>")
# Matches Markdown code-fence markers, optionally followed by a language label.
FENCE_RE = re.compile(r"```(?:json|html|markdown|text)?", re.IGNORECASE)
# Hugging Face model id loaded when --model is not given.
DEFAULT_BASEER_OCR_MODEL = "AbdoTarek/Baseer-OCR-V1.0"
# JSON fields that may hold the transcription, in order of preference.
_TEXT_KEYS = ("full_text", "text", "content")


def _extract_json_text(text: str) -> str | None:
    """Return the transcription stored in the first JSON object found in *text*.

    Returns None when *text* contains no decodable JSON object at its first "{",
    or when that object has none of the ``_TEXT_KEYS`` fields as a string. The
    first non-blank string field wins; if every such field is blank, ``""`` is
    returned (a blank page) instead of None.
    """
    start = text.find("{")
    if start == -1:
        return None
    try:
        # strict=False accepts literal newlines/tabs inside JSON strings, which
        # models often emit; raw_decode ignores any chatter after the object.
        payload, _ = json.JSONDecoder(strict=False).raw_decode(text, start)
    except json.JSONDecodeError:
        return None
    if not isinstance(payload, dict):
        return None
    candidates = [payload[key] for key in _TEXT_KEYS if isinstance(payload.get(key), str)]
    if not candidates:
        return None
    return next((value for value in candidates if value.strip()), "")


def clean_model_text(text: str) -> str:
    """Turn one raw model answer into plain page text.

    Removes Markdown code fences. Unwraps a JSON answer's ``full_text`` /
    ``text`` / ``content`` field when one is present. Replaces markup tags with
    line breaks. Finally drops blank lines and surrounding whitespace on each line.

    Args:
        text: Decoded model output for a single page.

    Returns:
        The cleaned text, with lines joined by "\\n"; "" when nothing remains.
    """
    text = FENCE_RE.sub("", text).strip()
    extracted = _extract_json_text(text)
    if extracted is not None:
        text = extracted
    # Tags are stripped only from the chosen text, after JSON decoding, so tag
    # removal can never break the JSON (the old order turned "<br>" inside a
    # JSON string into a raw newline and made the whole answer undecodable).
    text = TAG_RE.sub("\n", text)
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    return "\n".join(lines)
def _page_sort_key(path: Path) -> tuple[list[str | int], str]:
    """Return a sort key that orders page images naturally by file name.

    Digit runs compare numerically, so ``page-2.png`` sorts before
    ``page-10.png``; plain string sorting would put ``page-10`` first. For
    consistently zero-padded names the order matches plain string sorting. The
    full name breaks ties (e.g. ``p1`` vs ``p01``) so the order stays deterministic.
    """
    parts = re.split(r"(\d+)", path.name)
    # Splitting on a capturing group alternates text and digit runs, so odd
    # positions always hold digits and a str is never compared with an int.
    natural = [int(part) if position % 2 else part for position, part in enumerate(parts)]
    return natural, path.name


def main() -> None:
    """Command-line entry point: OCR every ``*.png`` in --image-dir into --out.

    Pages are processed in natural file-name order. Progress is printed to
    stdout as ``ARABIC_READER_PROGRESS <done> <total>`` lines. Pages whose
    cleaned text is empty are skipped. The remaining page texts are joined with
    a blank line and written to --out as UTF-8; parent directories are created.
    Invalid arguments exit through ``argparse`` (status 2).
    """
    # Force UTF-8 on the standard streams so Arabic text cannot raise
    # UnicodeEncodeError on consoles with a legacy default encoding.
    if hasattr(sys.stdout, "reconfigure"):
        sys.stdout.reconfigure(encoding="utf-8", errors="replace")
    if hasattr(sys.stderr, "reconfigure"):
        sys.stderr.reconfigure(encoding="utf-8", errors="replace")

    parser = argparse.ArgumentParser(
        description="Extract Arabic text from page images with Baseer OCR."
    )
    parser.add_argument("--image-dir", required=True, type=Path)
    parser.add_argument("--out", required=True, type=Path)
    parser.add_argument("--model", default=DEFAULT_BASEER_OCR_MODEL)
    parser.add_argument("--max-new-tokens", type=int, default=2048)
    args = parser.parse_args()
    # Without this check a mistyped path silently yields an empty transcript
    # (after loading the whole model).
    if not args.image_dir.is_dir():
        parser.error(f"--image-dir is not a directory: {args.image_dir}")
    if args.max_new_tokens <= 0:
        parser.error("--max-new-tokens must be a positive integer")

    image_paths = sorted(args.image_dir.glob("*.png"), key=_page_sort_key)
    # Never report a 0 denominator, even when there are no pages.
    total = max(len(image_paths), 1)
    print(f"ARABIC_READER_PROGRESS 0 {total}", flush=True)

    if not image_paths:
        # Nothing to OCR: write the (empty) result without loading the model.
        args.out.parent.mkdir(parents=True, exist_ok=True)
        args.out.write_text("", encoding="utf-8")
        print(f"ARABIC_READER_PROGRESS {total} {total}", flush=True)
        return

    # Heavy ML imports are deferred until after the first progress line has
    # been emitted and the arguments have been validated.
    import torch
    from PIL import Image
    from qwen_vl_utils import process_vision_info
    from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        args.model,
        torch_dtype="auto",
        device_map="auto",
    ).eval()
    processor = AutoProcessor.from_pretrained(args.model)
    prompt = (
        "Extract ALL visible Arabic text from the document image. "
        "Return only JSON with a full_text field. Preserve the original reading order. "
        "Do not summarize, translate, or add explanations."
    )

    pieces: list[str] = []
    for index, image_path in enumerate(image_paths, start=1):
        # The context manager closes the file handle even if conversion fails.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        messages = [
            {"role": "system", "content": [{"type": "text", "text": "You are an OCR assistant."}]},
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            },
        ]
        chat_text = processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[chat_text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(model.device)
        with torch.inference_mode():
            # Sampling is disabled so repeated runs give the same transcription.
            output_ids = model.generate(
                **inputs, max_new_tokens=args.max_new_tokens, do_sample=False
            )
        # Drop the prompt tokens so only the newly generated answer is decoded.
        result = processor.batch_decode(
            output_ids[:, inputs.input_ids.shape[1] :],
            skip_special_tokens=True,
        )[0]
        page_text = clean_model_text(result)
        if page_text:
            pieces.append(page_text)
        print(f"ARABIC_READER_PROGRESS {index} {total}", flush=True)

    args.out.parent.mkdir(parents=True, exist_ok=True)
    args.out.write_text("\n\n".join(pieces), encoding="utf-8")
if __name__ == "__main__":
    # Script entry point: run the CLI only when executed directly, never on
    # import. main() returns None, so the process exits with status 0 on success.
    raise SystemExit(main())