File size: 4,207 Bytes
2e1a095
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from __future__ import annotations

import argparse
import re
import sys
from pathlib import Path

# Anything shaped like a markup tag: "<", one or more non-">" characters, ">".
TAG_RE = re.compile(r"<[^>]+>")
# Default value of the --model command-line option.
DEFAULT_QARI_OCR_MODEL = "NAMAA-Space/Qari-OCR-0.4.0-VL-4B-Instruct"


def clean_model_text(text: str) -> str:
    """Reduce raw model output to plain, non-empty text lines.

    Tag-shaped spans are replaced by line breaks, triple-backtick fence
    markers (optionally followed by ``html``, ``markdown`` or ``text``, in
    any letter case) are removed, every line is stripped of surrounding
    whitespace, and blank lines are dropped.

    Args:
        text: Decoded model output.

    Returns:
        The remaining lines joined with ``"\\n"``; an empty string when
        nothing is left.
    """
    # Each tag becomes a newline so text on either side of it stays separated.
    untagged = TAG_RE.sub("\n", text)
    unfenced = re.sub(
        r"```(?:html|markdown|text)?", "", untagged, flags=re.IGNORECASE
    ).replace("```", "")

    kept_lines = []
    for raw_line in unfenced.splitlines():
        trimmed = raw_line.strip()
        if trimmed:
            kept_lines.append(trimmed)
    return "\n".join(kept_lines)


def _natural_sort_key(path: Path) -> tuple[list, str]:
    """Return a sort key that orders digit runs in the file name by value."""
    # re.split with a capturing group alternates text (even index) and digit
    # runs (odd index), so keys always compare str with str and int with int.
    tokens = re.split(r"(\d+)", path.name)
    natural = [int(token) if position % 2 else token for position, token in enumerate(tokens)]
    # The raw name breaks ties such as "page-1.png" vs "page-01.png".
    return natural, path.name


def main() -> None:
    """Run OCR over every ``*.png`` in a directory and write the combined text.

    Command-line options:
        --image-dir       Directory containing the page images (required).
        --out             Path of the UTF-8 text file to write (required).
        --model           Model identifier passed to ``from_pretrained``.
        --max-new-tokens  Generation budget per page (default 2048).

    Progress is reported on stdout as ``ARABIC_READER_PROGRESS <done> <total>``
    lines: once before any work and once after each page. Pages are processed
    in natural file-name order (``page-2`` before ``page-10``), and the cleaned
    text of each non-empty page is joined with a blank line in the output file.
    When the directory holds no PNG files an empty output file is written and
    the model is never loaded.
    """
    if hasattr(sys.stdout, "reconfigure"):
        sys.stdout.reconfigure(encoding="utf-8", errors="replace")
    if hasattr(sys.stderr, "reconfigure"):
        sys.stderr.reconfigure(encoding="utf-8", errors="replace")

    parser = argparse.ArgumentParser(description="Extract Arabic text from page images with QARI-OCR.")
    parser.add_argument("--image-dir", required=True, type=Path)
    parser.add_argument("--out", required=True, type=Path)
    parser.add_argument("--model", default=DEFAULT_QARI_OCR_MODEL)
    parser.add_argument("--max-new-tokens", type=int, default=2048)
    args = parser.parse_args()

    # Plain lexicographic sorting would put "page-10.png" before "page-2.png"
    # and scramble the page order of the output; sort digit runs numerically.
    image_paths = sorted(args.image_dir.glob("*.png"), key=_natural_sort_key)
    # The reported total is never 0, even when there are no images.
    total = max(len(image_paths), 1)
    print(f"ARABIC_READER_PROGRESS 0 {total}", flush=True)

    if not image_paths:
        # Nothing to recognise: produce the same empty output file as before
        # but skip importing torch/transformers and loading the model.
        print(f"No PNG images found in {args.image_dir}; writing empty output.", file=sys.stderr, flush=True)
        args.out.parent.mkdir(parents=True, exist_ok=True)
        args.out.write_text("", encoding="utf-8")
        return

    # Heavy third-party imports are deferred until there is work to do.
    import torch
    from transformers import AutoProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    processor = AutoProcessor.from_pretrained(args.model)
    prompt = (
        "Extract only the readable Arabic text from this scanned document page. "
        "Keep the natural reading order. Do not summarize or translate."
    )
    # The model name selects between two loading / input-preparation paths.
    uses_qwen3_vl = "qari-ocr-0.4" in args.model.lower() or "qwen3" in args.model.lower()
    if uses_qwen3_vl:
        from transformers import AutoModelForVision2Seq

        model = AutoModelForVision2Seq.from_pretrained(
            args.model,
            torch_dtype=dtype,
            device_map="auto" if device == "cuda" else None,
        )
    else:
        from qwen_vl_utils import process_vision_info
        from transformers import AutoModelForImageTextToText

        model = AutoModelForImageTextToText.from_pretrained(
            args.model,
            torch_dtype=dtype,
            device_map="auto" if device == "cuda" else None,
        )
    if device == "cpu":
        model.to(device)

    pieces: list[str] = []
    for index, image_path in enumerate(image_paths, start=1):
        # The Qwen3-VL path takes a plain absolute path; the other path hands
        # a file:// URI to process_vision_info.
        image_reference = str(image_path.resolve()) if uses_qwen3_vl else f"file://{image_path.resolve().as_posix()}"
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_reference},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        if uses_qwen3_vl:
            inputs = processor.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            ).to(model.device)
        else:
            text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            ).to(model.device)
        generated_ids = model.generate(**inputs, max_new_tokens=args.max_new_tokens, do_sample=False)
        # Drop the leading input-length positions so only the continuation is decoded.
        generated_ids = generated_ids[:, inputs.input_ids.shape[1] :]
        output = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        page_text = clean_model_text(output)
        if page_text:
            pieces.append(page_text)
        print(f"ARABIC_READER_PROGRESS {index} {total}", flush=True)

    args.out.parent.mkdir(parents=True, exist_ok=True)
    args.out.write_text("\n\n".join(pieces), encoding="utf-8")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()