| """ |
| HAFITH β Arabic Manuscript OCR Inference |
| ========================================= |
| Standalone script. Downloads models from HF Hub on first run. |
| |
| Usage: |
| python inference.py manuscript.jpg |
| python inference.py manuscript.jpg --output result.txt |
| python inference.py manuscript.jpg --gemini-key YOUR_KEY |
| """ |
|
|
| import argparse |
| import sys |
| import os |
| from pathlib import Path |
|
|
|
|
| |
|
|
def parse_args():
    """Parse and return the command-line arguments for the OCR script."""
    parser = argparse.ArgumentParser(description="HAFITH Arabic Manuscript OCR")
    # (flag names, options) pairs, added in the order they appear in --help.
    spec = [
        (("image",), {"help": "Path to manuscript image (JPG/PNG/TIFF)"}),
        (("--output", "-o"), {"help": "Save transcription to this file"}),
        (("--gemini-key",), {"help": "Gemini API key for AI post-correction"}),
        (("--device",), {"default": "auto", "help": "cuda / cpu / auto (default: auto)"}),
        (("--models-dir",), {"default": None, "help": "Local models directory (skips HF download)"}),
    ]
    for names, options in spec:
        parser.add_argument(*names, **options)
    return parser.parse_args()
|
|
|
|
| |
|
|
def get_models_dir(local_override=None):
    """Resolve the models directory.

    Returns *local_override* when it is truthy; otherwise downloads the
    model snapshot from the Hugging Face Hub (cached after the first run)
    and returns its local path.  Exits with status 1 if huggingface_hub
    is not installed.
    """
    # Guard clause: a user-supplied directory skips the Hub entirely.
    if local_override:
        return local_override

    # Imported lazily so the script works offline when --models-dir is given.
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        print("Run: pip install huggingface_hub")
        sys.exit(1)

    print("Downloading models from HF Hub (one-time, ~4.3 GB)...")
    snapshot_path = snapshot_download("mdnaseif/hafith-models")
    return snapshot_path
|
|
|
|
| |
|
|
def _vertical_order(polygons):
    """Return line indices sorted top-to-bottom by each polygon's mean y.

    Fallback reading order used when region detection is unavailable or
    low-confidence.  numpy is imported lazily, matching the original
    local-import behaviour.
    """
    import numpy as np
    return sorted(range(len(polygons)),
                  key=lambda i: np.array(polygons[i])[:, 1].mean())


def run(image_path, models_dir, device, gemini_key=None):
    """Run the full HAFITH OCR pipeline on a single manuscript image.

    Args:
        image_path: Path to the manuscript scan (JPG/PNG/TIFF).
        models_dir: Directory containing the downloaded model weights.
        device: "cuda" or "cpu".
        gemini_key: Optional Gemini API key; enables AI post-correction.

    Returns:
        The transcription as one string, one recognised line per output
        line, in reading order (main text first, then margin lines).

    Exits with status 1 on import failure or when no text lines are found.
    """
    # Make the repo-local "app" package importable when running from a clone.
    repo_app = Path(__file__).parent / "app"
    if repo_app.exists():
        sys.path.insert(0, str(repo_app))

    try:
        from pipeline import (
            load_lines_model, load_regions_model, load_ocr,
            segment, detect_regions, classify_lines_by_region,
            get_line_images, recognise_lines_batch,
        )
    except ImportError as e:
        print(f"Import error: {e}")
        print("Make sure you cloned https://github.com/mdnaseif/hafith_mvp and are running from there.")
        sys.exit(1)

    models_dir = Path(models_dir)

    # ---- Load all three models -------------------------------------------
    print(f"Loading models on {device}...")

    # NOTE(review): the stray "β" in these progress messages looks like
    # mojibake of a bullet/arrow character — confirm against the original
    # source encoding before changing; kept byte-identical here.
    print(" β RTMDet line segmentation")
    lines_model = load_lines_model(
        config_path=str(models_dir / "rtmdet_lines.py"),
        checkpoint_path=str(models_dir / "lines.pth"),
        device=device,
    )

    print(" β YOLO region detection")
    regions_model = load_regions_model(str(models_dir / "regions.pt"))

    print(" β OCR model (SigLIP2 + Qwen3-0.6B)")
    ocr_model, processor, tokenizer = load_ocr(str(models_dir / "ocr"), device=device)

    print("Models loaded.\n")

    # ---- Line segmentation -----------------------------------------------
    print("Segmenting lines...")
    image_bgr, polygons = segment(lines_model, image_path, conf=0.2)
    num_lines = len(polygons)

    if num_lines == 0:
        print("No text lines detected. Try a higher-resolution scan.")
        sys.exit(1)

    print(f"Found {num_lines} lines.")

    # ---- Reading order via region detection (best-effort) ----------------
    try:
        region_polys, region_conf = detect_regions(regions_model, image_path, conf=0.5)
        if region_conf >= 0.75:
            main_idx, margin_idx, _ = classify_lines_by_region(polygons, region_polys)
        else:
            # Regions found but not trusted: plain top-to-bottom order.
            main_idx, margin_idx = _vertical_order(polygons), []
    except Exception:
        # Region detection failed outright; deliberately best-effort, so
        # fall back to the same top-to-bottom order rather than aborting.
        main_idx, margin_idx = _vertical_order(polygons), []

    # ---- Text recognition ------------------------------------------------
    print("Recognising text...")
    line_images = get_line_images(image_bgr, polygons)
    reading_order = list(main_idx) + list(margin_idx)
    ordered_images = [line_images[i] for i in reading_order]

    texts = recognise_lines_batch(
        ocr_model, processor, tokenizer,
        ordered_images,
        device=device,
        max_patches=512,
        max_len=64,
        batch_size=8,
    )

    # Scatter recognised texts back to their original line indices.
    raw_texts = [""] * num_lines
    for idx, text in zip(reading_order, texts):
        raw_texts[idx] = text

    # ---- Optional AI post-correction -------------------------------------
    final_texts = list(raw_texts)
    if gemini_key:
        print("Applying AI post-correction...")
        os.environ["GEMINI_API_KEY"] = gemini_key
        from pipeline.correction import init_local_llm, correct_full_text_local
        corrector = init_local_llm("gemini-2.0-flash")
        final_texts = correct_full_text_local(corrector, raw_texts, sorted_indices=reading_order)

    # ---- Assemble output in reading order --------------------------------
    full_text = "\n".join(final_texts[i] for i in reading_order)

    # NOTE(review): "β" * 60 looks like a mojibake'd horizontal rule ("─"?)
    # — kept byte-identical; confirm the intended separator character.
    print("\n" + "β" * 60)
    print(full_text)
    print("β" * 60)
    print(f"\n{num_lines} lines recognised.")

    return full_text
|
|
|
|
| |
|
|
if __name__ == "__main__":
    cli = parse_args()

    # torch is imported here (not at module top) so --help stays fast.
    import torch

    # Resolve the compute device; "auto" prefers CUDA when available.
    device = cli.device
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Device: {device}")

    models_dir = get_models_dir(cli.models_dir)
    transcription = run(cli.image, models_dir, device, gemini_key=cli.gemini_key)

    if cli.output:
        Path(cli.output).write_text(transcription, encoding="utf-8")
        print(f"Saved to {cli.output}")
|