#!/usr/bin/env python3
"""
Run OCR inference over a directory of images and write:

- images/         : JPG copies of processed input images
- visualizations/ : Annotated images with bounding boxes + text
- results.json    : OCR outputs for all images
"""
import argparse
import json
import logging
import os
import random
import traceback
from pathlib import Path

from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm

from nemotron_ocr.inference.pipeline import NemotronOCR
from nemotron_ocr.inference.pipeline_v2 import NemotronOCRV2

SCRIPT_DIR = Path(__file__).parent.resolve()
FONT_PATH = SCRIPT_DIR / "NotoSansCJKsc-Regular.otf"

# Quiet the noisy pipeline loggers by default; --verbose re-enables INFO.
logging.getLogger("nemotron_ocr.inference.pipeline_v2").setLevel(logging.WARNING)
logging.getLogger("nemotron_ocr.inference.models.relational").setLevel(logging.WARNING)


def get_font(size: int = 14):
    """Return the bundled NotoSansCJK font at *size*, or PIL's default if unavailable."""
    try:
        return ImageFont.truetype(str(FONT_PATH), size)
    except Exception:
        return ImageFont.load_default()


def find_images(images_dir, extensions=(".png", ".jpg", ".jpeg", ".tiff", ".bmp")):
    """Recursively collect image files under *images_dir*.

    Both lower- and upper-case variants of each extension are globbed; results
    are de-duplicated (case-insensitive filesystems can match a file twice) and
    sorted for deterministic ordering.
    """
    images_dir = Path(images_dir)
    image_files = []
    for ext in extensions:
        image_files.extend(images_dir.glob(f"**/*{ext}"))
        image_files.extend(images_dir.glob(f"**/*{ext.upper()}"))
    return sorted(set(image_files))


def save_visualization(image_path: str, predictions: list, output_path: str):
    """Draw prediction boxes and text labels onto *image_path*, save as JPEG.

    Box coordinates in each prediction are treated as normalized to [0, 1] and
    scaled by the image size. Predictions whose "left" field is the literal
    string "nan" are skipped.
    """
    pil_image = Image.open(image_path).convert("RGB")
    draw = ImageDraw.Draw(pil_image)
    font = get_font(14)
    img_width, img_height = pil_image.size
    colors = [
        (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 165, 0),
        (128, 0, 128), (0, 128, 128), (255, 192, 203), (165, 42, 42),
    ]
    for i, pred in enumerate(predictions):
        if isinstance(pred.get("left"), str) and pred["left"] == "nan":
            continue
        color = colors[i % len(colors)]
        left = int(pred["left"] * img_width)
        right = int(pred["right"] * img_width)
        # Prediction schemas differ between pipelines: prefer upper/lower keys,
        # fall back to bottom/top. NOTE(review): the cross-mapping
        # (upper -> bottom, lower -> top) looks like a flipped-axis fallback —
        # confirm against the pipeline output schema.
        upper = int(pred.get("upper", pred.get("bottom", 0)) * img_height)
        lower = int(pred.get("lower", pred.get("top", 0)) * img_height)
        top, bottom = min(upper, lower), max(upper, lower)
        draw.rectangle([left, top, right, bottom], outline=color, width=2)
        display_text = pred["text"][:50] + "..." if len(pred["text"]) > 50 else pred["text"]
        text_y = max(0, top - 20)
        try:
            # White backing box behind the label so it stays readable on busy images.
            text_bbox = draw.textbbox((left, text_y), display_text, font=font)
            draw.rectangle(
                [text_bbox[0] - 2, text_bbox[1] - 2, text_bbox[2] + 2, text_bbox[3] + 2],
                fill=(255, 255, 255, 200),
                outline=color,
            )
            draw.text((left, text_y), display_text, fill=color, font=font)
        except Exception:
            # textbbox may be missing on older Pillow; draw the bare label instead.
            draw.text((left, text_y), display_text, fill=color, font=font)
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    pil_image.save(output_path, "JPEG", quality=90)


def save_image_as_jpg(image_path: str, output_path: str):
    """Copy *image_path* to *output_path* as an RGB JPEG (quality 90)."""
    pil_image = Image.open(image_path).convert("RGB")
    pil_image.save(output_path, "JPEG", quality=90)


def _store_result(image_path: Path, predictions: list, images_output_dir: Path, viz_output_dir: Path):
    """Persist the JPG copy and (if any predictions) visualization for one image.

    Returns ``(result_key, row)`` for insertion into the results dict.

    NOTE(review): the key is the file stem, so two inputs with the same stem in
    different subdirectories overwrite each other in results.json and on disk —
    confirm stems are unique across the input tree.
    """
    result_key = f"{image_path.stem}.jpg"
    row = {
        "original_path": str(image_path),
        "predictions": predictions,
        "num_detections": len(predictions),
        "status": "success",
    }
    dest_image_path = images_output_dir / f"{image_path.stem}.jpg"
    if not dest_image_path.exists():
        save_image_as_jpg(str(image_path), str(dest_image_path))
    if predictions:
        vis_path = viz_output_dir / f"{image_path.stem}_viz.jpg"
        save_visualization(str(image_path), predictions, str(vis_path))
    return result_key, row


def _error_row(image_path: Path, exc: Exception) -> dict:
    """Build a results.json row for a failed image.

    Must be called from inside the ``except`` block so ``format_exc`` still
    sees the active exception.
    """
    return {
        "original_path": str(image_path),
        "predictions": [],
        "num_detections": 0,
        "status": "error",
        "error": str(exc),
        "traceback": traceback.format_exc(),
    }


def _infer_single(ocr, image_path: Path, args, images_output_dir: Path, viz_output_dir: Path, results: dict):
    """Run OCR on one image and record either a success or an error row in *results*."""
    try:
        predictions = ocr(
            str(image_path),
            merge_level=args.merge_level,
            include_invalid=args.include_invalid,
        )
        key, row = _store_result(image_path, predictions, images_output_dir, viz_output_dir)
        results[key] = row
    except Exception as e:
        results[f"{image_path.stem}.jpg"] = _error_row(image_path, e)


def main():
    """CLI entry point: parse args, run the selected pipeline, write results.json."""
    parser = argparse.ArgumentParser(description="Run OCR inference on image directories")
    parser.add_argument("--model_dir", type=str, required=True, help="Path to model checkpoint directory")
    parser.add_argument("--images_dir", type=str, required=True, help="Path to directory with input images")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
    parser.add_argument("--num_samples", type=int, default=0, help="Number of images to process (0 for all)")
    parser.add_argument(
        "--merge_level",
        type=str,
        default="sentence",
        choices=["word", "sentence", "paragraph"],
        help="Text merging granularity",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling")
    parser.add_argument("--include_invalid", action="store_true", help="Include regions marked invalid")
    parser.add_argument("--infer_length", type=int, default=1024, help="Detector resolution")
    parser.add_argument("--pipeline", type=str, default="v2", choices=("v1", "v2"), help="OCR pipeline version")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for v2 pipeline")
    parser.add_argument("--verbose", action="store_true", help="Enable verbose OCR logs")
    args = parser.parse_args()

    if args.verbose:
        logging.getLogger("nemotron_ocr.inference.pipeline_v2").setLevel(logging.INFO)

    output_dir = Path(args.output_dir)
    images_output_dir = output_dir / "images"
    viz_output_dir = output_dir / "visualizations"
    output_dir.mkdir(parents=True, exist_ok=True)
    images_output_dir.mkdir(exist_ok=True)
    viz_output_dir.mkdir(exist_ok=True)

    image_files = [Path(p) for p in find_images(args.images_dir)]
    if not image_files:
        print("No images found.")
        return

    # Optional deterministic subsampling.
    if 0 < args.num_samples < len(image_files):
        random.seed(args.seed)
        image_files = random.sample(image_files, args.num_samples)

    if args.pipeline == "v2":
        ocr = NemotronOCRV2(
            model_dir=args.model_dir,
            infer_length=args.infer_length,
            detector_max_batch_size=max(1, args.batch_size),
        )
    else:
        ocr = NemotronOCR(args.model_dir, infer_length=args.infer_length)

    results = {}
    if args.pipeline == "v2":
        bs = max(1, args.batch_size)
        for start in tqdm(range(0, len(image_files), bs), desc="Batches"):
            batch_paths = image_files[start:start + bs]
            try:
                batch_preds = ocr(
                    [str(p) for p in batch_paths],
                    merge_level=args.merge_level,
                    include_invalid=args.include_invalid,
                )
            except Exception:
                # Batched call failed: retry images one at a time so a single
                # bad image cannot sink the whole batch.
                for image_path in batch_paths:
                    _infer_single(ocr, image_path, args, images_output_dir, viz_output_dir, results)
                continue
            for image_path, predictions in zip(batch_paths, batch_preds):
                try:
                    key, row = _store_result(image_path, predictions, images_output_dir, viz_output_dir)
                    results[key] = row
                except Exception as e:
                    results[f"{image_path.stem}.jpg"] = _error_row(image_path, e)
    else:
        for image_path in tqdm(image_files, desc="Processing"):
            _infer_single(ocr, image_path, args, images_output_dir, viz_output_dir, results)

    results_path = output_dir / "results.json"
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    success_count = sum(1 for r in results.values() if r["status"] == "success")
    total_detections = sum(r["num_detections"] for r in results.values())
    print(f"Done: {success_count}/{len(results)} images, {total_detections} detections -> {output_dir}")


if __name__ == "__main__":
    main()