""" HAFITH — Arabic Manuscript OCR Inference ========================================= Standalone script. Downloads models from HF Hub on first run. Usage: python inference.py manuscript.jpg python inference.py manuscript.jpg --output result.txt python inference.py manuscript.jpg --gemini-key YOUR_KEY """ import argparse import sys import os from pathlib import Path # ── CLI ─────────────────────────────────────────────────────────────────────── def parse_args(): p = argparse.ArgumentParser(description="HAFITH Arabic Manuscript OCR") p.add_argument("image", help="Path to manuscript image (JPG/PNG/TIFF)") p.add_argument("--output", "-o", help="Save transcription to this file") p.add_argument("--gemini-key", help="Gemini API key for AI post-correction") p.add_argument("--device", default="auto", help="cuda / cpu / auto (default: auto)") p.add_argument("--models-dir", default=None, help="Local models directory (skips HF download)") return p.parse_args() # ── Model download ───────────────────────────────────────────────────────────── def get_models_dir(local_override=None): if local_override: return local_override try: from huggingface_hub import snapshot_download except ImportError: print("Run: pip install huggingface_hub") sys.exit(1) print("Downloading models from HF Hub (one-time, ~4.3 GB)...") return snapshot_download("mdnaseif/hafith-models") # ── Pipeline ────────────────────────────────────────────────────────────────── def run(image_path, models_dir, device, gemini_key=None): # Add hafith app/ to path — works whether running from repo or standalone repo_app = Path(__file__).parent / "app" if repo_app.exists(): sys.path.insert(0, str(repo_app)) try: from pipeline import ( load_lines_model, load_regions_model, load_ocr, segment, detect_regions, classify_lines_by_region, get_line_images, recognise_lines_batch, ) except ImportError as e: print(f"Import error: {e}") print("Make sure you cloned https://github.com/mdnaseif/hafith_mvp and are running from there.") sys.exit(1) models_dir = Path(models_dir) # ── 1. Load models ───────────────────────────────────────────────────────── print(f"Loading models on {device}...") print(" → RTMDet line segmentation") lines_model = load_lines_model( config_path=str(models_dir / "rtmdet_lines.py"), checkpoint_path=str(models_dir / "lines.pth"), device=device, ) print(" → YOLO region detection") regions_model = load_regions_model(str(models_dir / "regions.pt")) print(" → OCR model (SigLIP2 + Qwen3-0.6B)") ocr_model, processor, tokenizer = load_ocr(str(models_dir / "ocr"), device=device) print("Models loaded.\n") # ── 2. Segment ───────────────────────────────────────────────────────────── print("Segmenting lines...") image_bgr, polygons = segment(lines_model, image_path, conf=0.2) num_lines = len(polygons) if num_lines == 0: print("No text lines detected. Try a higher-resolution scan.") sys.exit(1) print(f"Found {num_lines} lines.") # ── 3. Region classification ─────────────────────────────────────────────── try: region_polys, region_conf = detect_regions(regions_model, image_path, conf=0.5) if region_conf >= 0.75: main_idx, margin_idx, _ = classify_lines_by_region(polygons, region_polys) else: import numpy as np main_idx = sorted(range(num_lines), key=lambda i: np.array(polygons[i])[:, 1].mean()) margin_idx = [] except Exception: import numpy as np main_idx = sorted(range(num_lines), key=lambda i: np.array(polygons[i])[:, 1].mean()) margin_idx = [] # ── 4. OCR ───────────────────────────────────────────────────────────────── print("Recognising text...") line_images = get_line_images(image_bgr, polygons) reading_order = list(main_idx) + list(margin_idx) ordered_images = [line_images[i] for i in reading_order] texts = recognise_lines_batch( ocr_model, processor, tokenizer, ordered_images, device=device, max_patches=512, max_len=64, batch_size=8, ) raw_texts = [""] * num_lines for idx, text in zip(reading_order, texts): raw_texts[idx] = text # ── 5. Gemini correction (optional) ─────────────────────────────────────── final_texts = list(raw_texts) if gemini_key: print("Applying AI post-correction...") os.environ["GEMINI_API_KEY"] = gemini_key from pipeline.correction import init_local_llm, correct_full_text_local corrector = init_local_llm("gemini-2.0-flash") final_texts = correct_full_text_local(corrector, raw_texts, sorted_indices=reading_order) # ── 6. Output ────────────────────────────────────────────────────────────── full_text = "\n".join(final_texts[i] for i in reading_order) print("\n" + "─" * 60) print(full_text) print("─" * 60) print(f"\n{num_lines} lines recognised.") return full_text # ── Entry point ─────────────────────────────────────────────────────────────── if __name__ == "__main__": args = parse_args() import torch if args.device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Device: {device}") else: device = args.device models_dir = get_models_dir(args.models_dir) result = run(args.image, models_dir, device, gemini_key=args.gemini_key) if args.output: Path(args.output).write_text(result, encoding="utf-8") print(f"Saved to {args.output}")