# hafith-models / inference.py
# Uploaded by mdnaseif via huggingface_hub (commit 9f4c20d, verified)
"""
HAFITH β€” Arabic Manuscript OCR Inference
=========================================
Standalone script. Downloads models from HF Hub on first run.
Usage:
python inference.py manuscript.jpg
python inference.py manuscript.jpg --output result.txt
python inference.py manuscript.jpg --gemini-key YOUR_KEY
"""
import argparse
import sys
import os
from pathlib import Path
# ── CLI ───────────────────────────────────────────────────────────────────────
def parse_args():
    """Parse and return the command-line arguments for the HAFITH OCR CLI."""
    parser = argparse.ArgumentParser(description="HAFITH Arabic Manuscript OCR")
    # Positional: the image to transcribe.
    parser.add_argument("image", help="Path to manuscript image (JPG/PNG/TIFF)")
    # Optional behaviour switches.
    parser.add_argument("--output", "-o", help="Save transcription to this file")
    parser.add_argument("--gemini-key", help="Gemini API key for AI post-correction")
    parser.add_argument("--device", default="auto",
                        help="cuda / cpu / auto (default: auto)")
    parser.add_argument("--models-dir", default=None,
                        help="Local models directory (skips HF download)")
    return parser.parse_args()
# ── Model download ─────────────────────────────────────────────────────────────
def get_models_dir(local_override=None):
    """Return the directory containing the model weights.

    A truthy *local_override* is returned unchanged (no download, no
    huggingface_hub import needed).  Otherwise the weights are fetched
    from the Hugging Face Hub — a one-time ~4.3 GB download that is
    cached for subsequent runs.
    """
    if local_override:
        return local_override

    # Import lazily so a local models dir works without huggingface_hub.
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        print("Run: pip install huggingface_hub")
        sys.exit(1)

    print("Downloading models from HF Hub (one-time, ~4.3 GB)...")
    return snapshot_download("mdnaseif/hafith-models")
# ── Pipeline ──────────────────────────────────────────────────────────────────
def _fallback_reading_order(polygons):
    """Return line indices sorted top-to-bottom by mean polygon y-coordinate.

    Used whenever region detection is unavailable or low-confidence.
    """
    import numpy as np
    return sorted(range(len(polygons)),
                  key=lambda i: np.array(polygons[i])[:, 1].mean())


def run(image_path, models_dir, device, gemini_key=None):
    """Run the full HAFITH OCR pipeline on one manuscript image.

    Args:
        image_path: Path to the manuscript image (JPG/PNG/TIFF).
        models_dir: Directory containing the downloaded model weights.
        device: Torch device string, e.g. "cuda" or "cpu".
        gemini_key: Optional Gemini API key; enables AI post-correction.

    Returns:
        The recognised text as a single string, one recognised line per
        text line, in reading order (main-body lines first, then margins).

    Exits the process (status 1) if the pipeline package cannot be
    imported or no text lines are detected.
    """
    # Add hafith app/ to path — works whether running from repo or standalone.
    repo_app = Path(__file__).parent / "app"
    if repo_app.exists():
        sys.path.insert(0, str(repo_app))
    try:
        from pipeline import (
            load_lines_model, load_regions_model, load_ocr,
            segment, detect_regions, classify_lines_by_region,
            get_line_images, recognise_lines_batch,
        )
    except ImportError as e:
        print(f"Import error: {e}")
        print("Make sure you cloned https://github.com/mdnaseif/hafith_mvp and are running from there.")
        sys.exit(1)

    models_dir = Path(models_dir)

    # ── 1. Load models ─────────────────────────────────────────────────────
    print(f"Loading models on {device}...")
    print("  → RTMDet line segmentation")
    lines_model = load_lines_model(
        config_path=str(models_dir / "rtmdet_lines.py"),
        checkpoint_path=str(models_dir / "lines.pth"),
        device=device,
    )
    print("  → YOLO region detection")
    regions_model = load_regions_model(str(models_dir / "regions.pt"))
    print("  → OCR model (SigLIP2 + Qwen3-0.6B)")
    ocr_model, processor, tokenizer = load_ocr(str(models_dir / "ocr"), device=device)
    print("Models loaded.\n")

    # ── 2. Segment ─────────────────────────────────────────────────────────
    print("Segmenting lines...")
    image_bgr, polygons = segment(lines_model, image_path, conf=0.2)
    num_lines = len(polygons)
    if num_lines == 0:
        print("No text lines detected. Try a higher-resolution scan.")
        sys.exit(1)
    print(f"Found {num_lines} lines.")

    # ── 3. Region classification ───────────────────────────────────────────
    # Region detection is best-effort: on low confidence or any failure we
    # fall back to a simple top-to-bottom ordering rather than aborting.
    try:
        region_polys, region_conf = detect_regions(regions_model, image_path, conf=0.5)
        if region_conf >= 0.75:
            main_idx, margin_idx, _ = classify_lines_by_region(polygons, region_polys)
        else:
            main_idx, margin_idx = _fallback_reading_order(polygons), []
    except Exception as e:
        print(f"Region detection failed ({e}); falling back to top-to-bottom order.")
        main_idx, margin_idx = _fallback_reading_order(polygons), []

    # ── 4. OCR ─────────────────────────────────────────────────────────────
    print("Recognising text...")
    line_images = get_line_images(image_bgr, polygons)
    # Main-body lines first, margin lines after.
    reading_order = list(main_idx) + list(margin_idx)
    ordered_images = [line_images[i] for i in reading_order]
    texts = recognise_lines_batch(
        ocr_model, processor, tokenizer,
        ordered_images,
        device=device,
        max_patches=512,
        max_len=64,
        batch_size=8,
    )
    # Scatter recognised text back to original line indices.
    raw_texts = [""] * num_lines
    for idx, text in zip(reading_order, texts):
        raw_texts[idx] = text

    # ── 5. Gemini correction (optional) ────────────────────────────────────
    final_texts = list(raw_texts)
    if gemini_key:
        print("Applying AI post-correction...")
        os.environ["GEMINI_API_KEY"] = gemini_key
        from pipeline.correction import init_local_llm, correct_full_text_local
        corrector = init_local_llm("gemini-2.0-flash")
        final_texts = correct_full_text_local(corrector, raw_texts, sorted_indices=reading_order)

    # ── 6. Output ──────────────────────────────────────────────────────────
    full_text = "\n".join(final_texts[i] for i in reading_order)
    print("\n" + "─" * 60)
    print(full_text)
    print("─" * 60)
    print(f"\n{num_lines} lines recognised.")
    return full_text
# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    args = parse_args()

    # Resolve the torch device; "auto" picks CUDA when available.
    import torch
    device = args.device
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Device: {device}")

    transcription = run(args.image, get_models_dir(args.models_dir),
                        device, gemini_key=args.gemini_key)

    if args.output:
        Path(args.output).write_text(transcription, encoding="utf-8")
        print(f"Saved to {args.output}")