File size: 7,110 Bytes
9f4c20d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""
HAFITH β€” Arabic Manuscript OCR Inference
=========================================
Standalone script. Downloads models from HF Hub on first run.

Usage:
    python inference.py manuscript.jpg
    python inference.py manuscript.jpg --output result.txt
    python inference.py manuscript.jpg --gemini-key YOUR_KEY
"""

import argparse
import sys
import os
from pathlib import Path


# ── CLI ───────────────────────────────────────────────────────────────────────

def parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="HAFITH Arabic Manuscript OCR")
    parser.add_argument("image", help="Path to manuscript image (JPG/PNG/TIFF)")
    # Optional flags, registered from a single spec table.
    optional_flags = [
        (("--output", "-o"), {"help": "Save transcription to this file"}),
        (("--gemini-key",), {"help": "Gemini API key for AI post-correction"}),
        (("--device",), {"default": "auto", "help": "cuda / cpu / auto (default: auto)"}),
        (("--models-dir",), {"default": None, "help": "Local models directory (skips HF download)"}),
    ]
    for flags, kwargs in optional_flags:
        parser.add_argument(*flags, **kwargs)
    return parser.parse_args()


# ── Model download ─────────────────────────────────────────────────────────────

def get_models_dir(local_override=None):
    """Return the directory holding the model weights.

    A caller-supplied path is returned unchanged; otherwise the weights
    are pulled from the Hugging Face Hub (cached after the first run).
    Exits with status 1 when huggingface_hub is not installed.
    """
    if not local_override:
        try:
            from huggingface_hub import snapshot_download
        except ImportError:
            print("Run: pip install huggingface_hub")
            sys.exit(1)

        print("Downloading models from HF Hub (one-time, ~4.3 GB)...")
        return snapshot_download("mdnaseif/hafith-models")

    return local_override


# ── Pipeline ──────────────────────────────────────────────────────────────────

def _top_to_bottom_order(polygons):
    """Fallback reading order: line indices sorted by mean polygon y-coordinate.

    Used when region detection fails or is not confident enough to trust a
    layout-based main/margin split; gives a plain top-to-bottom ordering.
    """
    import numpy as np
    return sorted(range(len(polygons)),
                  key=lambda i: np.array(polygons[i])[:, 1].mean())


def run(image_path, models_dir, device, gemini_key=None):
    """Run the full OCR pipeline on a single manuscript image.

    Steps: load the three models, segment text lines, classify lines into
    main-text vs margin via region detection (falling back to a plain
    top-to-bottom order), OCR every line, optionally apply Gemini
    post-correction, then print and return the transcription.

    Parameters
    ----------
    image_path : str
        Path to the manuscript image.
    models_dir : str or Path
        Directory containing the model weights.
    device : str
        "cuda" or "cpu".
    gemini_key : str, optional
        Gemini API key; when given, AI post-correction is applied.

    Returns
    -------
    str
        Recognised text, one line per detected text line, in reading order.

    Exits the process (status 1) when the pipeline package cannot be
    imported or no text lines are detected.
    """
    # Add hafith app/ to path — works whether running from repo or standalone
    repo_app = Path(__file__).parent / "app"
    if repo_app.exists():
        sys.path.insert(0, str(repo_app))

    try:
        from pipeline import (
            load_lines_model, load_regions_model, load_ocr,
            segment, detect_regions, classify_lines_by_region,
            get_line_images, recognise_lines_batch,
        )
    except ImportError as e:
        print(f"Import error: {e}")
        print("Make sure you cloned https://github.com/mdnaseif/hafith_mvp and are running from there.")
        sys.exit(1)

    models_dir = Path(models_dir)

    # ── 1. Load models ─────────────────────────────────────────────────────────
    print(f"Loading models on {device}...")

    print("  → RTMDet line segmentation")
    lines_model = load_lines_model(
        config_path=str(models_dir / "rtmdet_lines.py"),
        checkpoint_path=str(models_dir / "lines.pth"),
        device=device,
    )

    print("  → YOLO region detection")
    regions_model = load_regions_model(str(models_dir / "regions.pt"))

    print("  → OCR model (SigLIP2 + Qwen3-0.6B)")
    ocr_model, processor, tokenizer = load_ocr(str(models_dir / "ocr"), device=device)

    print("Models loaded.\n")

    # ── 2. Segment ─────────────────────────────────────────────────────────────
    print("Segmenting lines...")
    image_bgr, polygons = segment(lines_model, image_path, conf=0.2)
    num_lines = len(polygons)

    if num_lines == 0:
        print("No text lines detected. Try a higher-resolution scan.")
        sys.exit(1)

    print(f"Found {num_lines} lines.")

    # ── 3. Region classification ───────────────────────────────────────────────
    # Region detection is best-effort: low confidence or any failure falls
    # back to a simple top-to-bottom ordering instead of aborting the run.
    try:
        region_polys, region_conf = detect_regions(regions_model, image_path, conf=0.5)
        if region_conf >= 0.75:
            main_idx, margin_idx, _ = classify_lines_by_region(polygons, region_polys)
        else:
            main_idx = _top_to_bottom_order(polygons)
            margin_idx = []
    except Exception as e:
        # Don't swallow the failure silently — report it, then fall back.
        print(f"Region detection failed ({e}); using top-to-bottom order.")
        main_idx = _top_to_bottom_order(polygons)
        margin_idx = []

    # ── 4. OCR ─────────────────────────────────────────────────────────────────
    print("Recognising text...")
    line_images    = get_line_images(image_bgr, polygons)
    # Main-text lines first, margin lines after — that's the reading order.
    reading_order  = list(main_idx) + list(margin_idx)
    ordered_images = [line_images[i] for i in reading_order]

    texts = recognise_lines_batch(
        ocr_model, processor, tokenizer,
        ordered_images,
        device=device,
        max_patches=512,
        max_len=64,
        batch_size=8,
    )

    # Scatter recognised texts back to their original line indices.
    raw_texts = [""] * num_lines
    for idx, text in zip(reading_order, texts):
        raw_texts[idx] = text

    # ── 5. Gemini correction (optional) ───────────────────────────────────────
    final_texts = list(raw_texts)
    if gemini_key:
        print("Applying AI post-correction...")
        # The correction module reads the key from the environment.
        os.environ["GEMINI_API_KEY"] = gemini_key
        from pipeline.correction import init_local_llm, correct_full_text_local
        corrector   = init_local_llm("gemini-2.0-flash")
        final_texts = correct_full_text_local(corrector, raw_texts, sorted_indices=reading_order)

    # ── 6. Output ──────────────────────────────────────────────────────────────
    full_text = "\n".join(final_texts[i] for i in reading_order)

    print("\n" + "─" * 60)
    print(full_text)
    print("─" * 60)
    print(f"\n{num_lines} lines recognised.")

    return full_text


# ── Entry point ───────────────────────────────────────────────────────────────

if __name__ == "__main__":
    args = parse_args()

    import torch

    # Resolve the execution device before any model is loaded.
    if args.device != "auto":
        device = args.device
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Device: {device}")

    models_dir = get_models_dir(args.models_dir)
    result = run(args.image, models_dir, device, gemini_key=args.gemini_key)

    if args.output:
        out_path = Path(args.output)
        out_path.write_text(result, encoding="utf-8")
        print(f"Saved to {args.output}")