# arabic-audio-reader-worker / scripts/arabic_glm_ocr_extract.py
# Commit 2e1a095 (verified) by Syncre: "Deploy Arabic Audio Reader worker"
from __future__ import annotations
import argparse
import re
import sys
from pathlib import Path
# Matches any HTML/XML-style tag; tags are replaced by line breaks so text
# that sat in separate elements ends up on separate lines.
TAG_RE = re.compile(r"<[^>]+>")
# Hugging Face model id used when --model is not given on the command line.
DEFAULT_ARABIC_GLM_OCR_MODEL = "sherif1313/Arabic-GLM-OCR-v2"


def clean_model_text(text: str) -> str:
    """Strip markup noise from raw OCR model output.

    Tags become line breaks. Triple-backtick fences are removed, including an
    optional ``html``/``markdown``/``text``/``json`` hint (case-insensitive).
    Every remaining line is trimmed, and blank lines are dropped.

    Args:
        text: Decoded text produced by the model.

    Returns:
        The surviving non-empty lines joined with ``"\\n"``; ``""`` if none.
    """
    untagged = TAG_RE.sub("\n", text)
    fence_free = re.sub(
        r"```(?:html|markdown|text|json)?",
        "",
        untagged,
        flags=re.IGNORECASE,
    ).replace("```", "")
    trimmed = (raw.strip() for raw in fence_free.splitlines())
    # filter(None, ...) discards the lines that were whitespace-only.
    return "\n".join(filter(None, trimmed))
def _page_sort_key(path: Path) -> tuple[tuple[object, ...], str]:
    """Return a natural-sort key so ``page2.png`` sorts before ``page10.png``.

    Digit runs in the file name compare numerically. Zero-padded names keep
    the same order a plain string sort would give. The full name is the final
    tiebreaker, so names like ``p1``/``p01`` still sort deterministically.
    """
    # re.split with a capturing group always yields text, digits, text, ...,
    # so odd indices are digit runs. Keys therefore compare str-to-str and
    # int-to-int position by position.
    parts = re.split(r"(\d+)", path.name)
    natural = tuple(int(part) if i % 2 else part for i, part in enumerate(parts))
    return natural, path.name


def _load_model(model_name: str):
    """Load the processor and image-to-text model for *model_name*.

    Uses float16 with ``device_map="auto"`` when CUDA is available, and
    float32 on the CPU otherwise.

    Returns:
        A ``(processor, model)`` pair.
    """
    import torch
    from transformers import AutoModelForImageTextToText, AutoProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModelForImageTextToText.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map="auto" if device == "cuda" else None,
    )
    if device == "cpu":
        # device_map is only passed on CUDA, so move the model explicitly here.
        model.to(device)
    return processor, model


def _ocr_page(processor, model, image_path: Path, prompt: str, max_new_tokens: int) -> str:
    """Run greedy (non-sampling) OCR on one page image and return the cleaned text."""
    import torch
    from PIL import Image

    # The context manager closes the file handle once the RGB copy exists.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Skip the first input_ids.shape[1] tokens (the prompt) and decode only new ones.
    prompt_length = inputs["input_ids"].shape[1]
    result = processor.decode(output[0][prompt_length:], skip_special_tokens=True)
    return clean_model_text(result)


def main() -> None:
    """OCR every ``*.png`` in ``--image-dir`` and write the text to ``--out``.

    Pages are processed in natural file-name order. After each page, a line
    ``ARABIC_READER_PROGRESS <done> <total>`` is printed to stdout. Pages
    whose cleaned text is empty are omitted. The remaining pages are written
    to ``--out`` as UTF-8, separated by a blank line.

    Exits via ``parser.error`` (status 2) in two cases: ``--image-dir`` is not
    a directory, or ``--max-new-tokens`` is not positive.
    """
    for stream in (sys.stdout, sys.stderr):
        if hasattr(stream, "reconfigure"):
            stream.reconfigure(encoding="utf-8", errors="replace")

    parser = argparse.ArgumentParser(description="Extract Arabic text from page images with Arabic-GLM-OCR.")
    parser.add_argument("--image-dir", required=True, type=Path)
    parser.add_argument("--out", required=True, type=Path)
    parser.add_argument("--model", default=DEFAULT_ARABIC_GLM_OCR_MODEL)
    parser.add_argument("--max-new-tokens", type=int, default=2048)
    args = parser.parse_args()
    # A missing directory would otherwise glob to nothing and "succeed" silently.
    if not args.image_dir.is_dir():
        parser.error(f"--image-dir is not a directory: {args.image_dir}")
    if args.max_new_tokens < 1:
        parser.error("--max-new-tokens must be a positive integer")

    image_paths = sorted(args.image_dir.glob("*.png"), key=_page_sort_key)
    # max(..., 1) keeps the progress denominator non-zero when there are no pages.
    total = max(len(image_paths), 1)
    print(f"ARABIC_READER_PROGRESS 0 {total}", flush=True)

    # Create the destination first so a bad --out fails before the long OCR run.
    args.out.parent.mkdir(parents=True, exist_ok=True)
    if not image_paths:
        # Nothing to OCR: skip the expensive model load but still write the
        # same (empty) output file the full path would produce.
        args.out.write_text("", encoding="utf-8")
        return

    processor, model = _load_model(args.model)
    prompt = (
        "Extract the Arabic text exactly as it appears on this scanned page. "
        "Preserve reading order. Do not summarize, translate, explain, or correct the text."
    )
    pieces: list[str] = []
    for index, image_path in enumerate(image_paths, start=1):
        page_text = _ocr_page(processor, model, image_path, prompt, args.max_new_tokens)
        if page_text:
            pieces.append(page_text)
        print(f"ARABIC_READER_PROGRESS {index} {total}", flush=True)
    args.out.write_text("\n\n".join(pieces), encoding="utf-8")
# Script entry point. main() returns None, so sys.exit() reports status 0 on success.
if __name__ == "__main__":
    sys.exit(main())