nemotron-ocr-v2 / run_ocr_inference.py
Ryan Chesler
fix missing runtime/build deps and add Slurm helper scripts.
pyproject.toml: add numpy and huggingface_hub to dependencies, and editables to build-system.requires, so editable installs work cleanly without manual workarounds.
New scripts:
- submit_install_gpu.sh: GPU-node venv creation + package build
- submit_inference_viz_gpu.sh: GPU-node batched v2 inference + viz
- run_ocr_inference.py: directory-level inference with JSON output
c239c5e
#!/usr/bin/env python3
"""
Run OCR inference over a directory of images and write:
- images/ : JPG copies of processed input images
- visualizations/ : Annotated images with bounding boxes + text
- results.json : OCR outputs for all images
"""
import argparse
import json
import logging
import os
import random
from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
from nemotron_ocr.inference.pipeline import NemotronOCR
from nemotron_ocr.inference.pipeline_v2 import NemotronOCRV2
# Directory containing this script; the bundled font is resolved relative to it.
SCRIPT_DIR = Path(__file__).parent.resolve()
# CJK-capable font shipped alongside the script, used when drawing recognized
# text onto the visualization images.
FONT_PATH = SCRIPT_DIR / "NotoSansCJKsc-Regular.otf"
# Quiet the OCR pipeline loggers by default; --verbose re-enables INFO for the
# v2 pipeline logger in main().
logging.getLogger("nemotron_ocr.inference.pipeline_v2").setLevel(logging.WARNING)
logging.getLogger("nemotron_ocr.inference.models.relational").setLevel(logging.WARNING)
def get_font(size: int = 14):
    """Load the bundled CJK-capable font at *size* points.

    Falls back to PIL's built-in bitmap font when the bundled font file is
    missing or cannot be loaded for any reason.
    """
    font_file = str(FONT_PATH)
    try:
        font = ImageFont.truetype(font_file, size)
    except Exception:
        font = ImageFont.load_default()
    return font
def find_images(images_dir, extensions=(".png", ".jpg", ".jpeg", ".tiff", ".bmp")):
    """Recursively collect image files under *images_dir*.

    Both lower- and upper-case variants of each extension are searched;
    duplicates (possible on case-insensitive filesystems) are removed and
    the paths are returned in sorted order.
    """
    root = Path(images_dir)
    found = set()
    for suffix in extensions:
        for pattern in (f"**/*{suffix}", f"**/*{suffix.upper()}"):
            found.update(root.glob(pattern))
    return sorted(found)
def save_visualization(image_path: str, predictions: list, output_path: str):
    """Render *predictions* as labeled bounding boxes over the image and save
    the result as a JPEG at *output_path*.

    Each prediction dict is expected to carry "left"/"right" plus
    "upper"/"lower" (or "bottom"/"top") coordinates and a "text" string.
    Coordinates are multiplied by the pixel dimensions below, so they are
    assumed to be normalized fractions in [0, 1] — TODO confirm against the
    pipeline output format.
    """
    pil_image = Image.open(image_path).convert("RGB")
    draw = ImageDraw.Draw(pil_image)
    font = get_font(14)
    img_width, img_height = pil_image.size
    # Fixed palette cycled per prediction index.
    colors = [
        (255, 0, 0),
        (0, 255, 0),
        (0, 0, 255),
        (255, 165, 0),
        (128, 0, 128),
        (0, 128, 128),
        (255, 192, 203),
        (165, 42, 42),
    ]
    for i, pred in enumerate(predictions):
        # Skip rows whose coordinates were serialized as the literal string
        # "nan" (invalid detection). NOTE(review): a float NaN here would
        # slip past this check and raise in int() below — verify upstream
        # never emits float NaNs.
        if isinstance(pred.get("left"), str) and pred["left"] == "nan":
            continue
        color = colors[i % len(colors)]
        left = int(pred["left"] * img_width)
        right = int(pred["right"] * img_width)
        # Vertical keys fall back crosswise ("upper"->"bottom",
        # "lower"->"top"), presumably to absorb a y-axis convention
        # difference between pipeline versions — TODO confirm.
        upper = int(pred.get("upper", pred.get("bottom", 0)) * img_height)
        lower = int(pred.get("lower", pred.get("top", 0)) * img_height)
        # Normalize so top <= bottom regardless of which convention applied.
        top, bottom = min(upper, lower), max(upper, lower)
        draw.rectangle([left, top, right, bottom], outline=color, width=2)
        # Truncate long strings so labels stay readable.
        display_text = pred["text"][:50] + "..." if len(pred["text"]) > 50 else pred["text"]
        # Place the label just above the box, clamped to the image top.
        text_y = max(0, top - 20)
        try:
            # Draw a white backing rectangle sized to the text for contrast.
            # The RGBA fill on an RGB image may raise depending on the PIL
            # version, hence the bare-text fallback below.
            text_bbox = draw.textbbox((left, text_y), display_text, font=font)
            draw.rectangle(
                [text_bbox[0] - 2, text_bbox[1] - 2, text_bbox[2] + 2, text_bbox[3] + 2],
                fill=(255, 255, 255, 200),
                outline=color,
            )
            draw.text((left, text_y), display_text, fill=color, font=font)
        except Exception:
            # Fallback: draw the text without a backing rectangle.
            draw.text((left, text_y), display_text, fill=color, font=font)
    os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else ".", exist_ok=True)
    pil_image.save(output_path, "JPEG", quality=90)
def save_image_as_jpg(image_path: str, output_path: str):
    """Write a quality-90 JPEG copy of the image at *image_path* to *output_path*."""
    Image.open(image_path).convert("RGB").save(output_path, "JPEG", quality=90)
def _store_result(image_path: Path, predictions: list, images_output_dir: Path, viz_output_dir: Path):
    """Persist one successfully processed image.

    Copies the source image into *images_output_dir* as a JPG (skipped if the
    copy already exists), renders a visualization into *viz_output_dir* when
    there are predictions, and returns ``(result_key, row)`` for the results
    dict.

    NOTE(review): keys derive from the file stem only, so images in different
    subdirectories sharing a stem overwrite each other — confirm inputs are
    stem-unique.
    """
    stem = image_path.stem
    row = {
        "original_path": str(image_path),
        "predictions": predictions,
        "num_detections": len(predictions),
        "status": "success",
    }
    jpg_copy = images_output_dir / f"{stem}.jpg"
    if not jpg_copy.exists():
        save_image_as_jpg(str(image_path), str(jpg_copy))
    if predictions:
        save_visualization(str(image_path), predictions, str(viz_output_dir / f"{stem}_viz.jpg"))
    return f"{stem}.jpg", row
def _error_row(image_path: Path, exc: Exception) -> dict:
    """Build the results-row dict for an image whose processing raised *exc*.

    Must be called from inside the ``except`` block so that
    ``traceback.format_exc()`` captures the active exception.
    """
    import traceback

    return {
        "original_path": str(image_path),
        "predictions": [],
        "num_detections": 0,
        "status": "error",
        "error": str(exc),
        "traceback": traceback.format_exc(),
    }


def _process_single(ocr, image_path: Path, args, images_output_dir: Path, viz_output_dir: Path, results: dict):
    """Run OCR on one image and record a success or error row in *results*."""
    try:
        predictions = ocr(
            str(image_path),
            merge_level=args.merge_level,
            include_invalid=args.include_invalid,
        )
        key, row = _store_result(image_path, predictions, images_output_dir, viz_output_dir)
        results[key] = row
    except Exception as e:
        results[f"{image_path.stem}.jpg"] = _error_row(image_path, e)


def main():
    """CLI entry point.

    Runs OCR over every image under --images_dir (optionally a seeded random
    subsample) and writes images/, visualizations/ and results.json into
    --output_dir. The v2 pipeline processes images in batches and falls back
    to per-image calls when a batched call fails.
    """
    parser = argparse.ArgumentParser(description="Run OCR inference on image directories")
    parser.add_argument("--model_dir", type=str, required=True, help="Path to model checkpoint directory")
    parser.add_argument("--images_dir", type=str, required=True, help="Path to directory with input images")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
    parser.add_argument("--num_samples", type=int, default=0, help="Number of images to process (0 for all)")
    parser.add_argument(
        "--merge_level",
        type=str,
        default="sentence",
        choices=["word", "sentence", "paragraph"],
        help="Text merging granularity",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling")
    parser.add_argument("--include_invalid", action="store_true", help="Include regions marked invalid")
    parser.add_argument("--infer_length", type=int, default=1024, help="Detector resolution")
    parser.add_argument("--pipeline", type=str, default="v2", choices=("v1", "v2"), help="OCR pipeline version")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for v2 pipeline")
    parser.add_argument("--verbose", action="store_true", help="Enable verbose OCR logs")
    args = parser.parse_args()
    if args.verbose:
        # Module setup silenced this logger; restore INFO on request.
        logging.getLogger("nemotron_ocr.inference.pipeline_v2").setLevel(logging.INFO)
    output_dir = Path(args.output_dir)
    images_output_dir = output_dir / "images"
    viz_output_dir = output_dir / "visualizations"
    output_dir.mkdir(parents=True, exist_ok=True)
    images_output_dir.mkdir(exist_ok=True)
    viz_output_dir.mkdir(exist_ok=True)
    image_files = [Path(p) for p in find_images(args.images_dir)]
    if not image_files:
        print("No images found.")
        return
    if 0 < args.num_samples < len(image_files):
        # Deterministic subsample so repeated runs pick the same images.
        random.seed(args.seed)
        image_files = random.sample(image_files, args.num_samples)
    if args.pipeline == "v2":
        ocr = NemotronOCRV2(
            model_dir=args.model_dir,
            infer_length=args.infer_length,
            detector_max_batch_size=max(1, args.batch_size),
        )
    else:
        ocr = NemotronOCR(args.model_dir, infer_length=args.infer_length)
    results = {}
    if args.pipeline == "v2":
        bs = max(1, args.batch_size)
        for start in tqdm(range(0, len(image_files), bs), desc="Batches"):
            batch_paths = image_files[start:start + bs]
            try:
                batch_preds = ocr(
                    [str(p) for p in batch_paths],
                    merge_level=args.merge_level,
                    include_invalid=args.include_invalid,
                )
            except Exception:
                # Best-effort: a failed batched call falls back to per-image
                # calls below so one bad input cannot sink the whole batch.
                batch_preds = None
            if batch_preds is None:
                for image_path in batch_paths:
                    _process_single(ocr, image_path, args, images_output_dir, viz_output_dir, results)
                continue
            for image_path, predictions in zip(batch_paths, batch_preds):
                try:
                    key, row = _store_result(image_path, predictions, images_output_dir, viz_output_dir)
                    results[key] = row
                except Exception as e:
                    results[f"{image_path.stem}.jpg"] = _error_row(image_path, e)
    else:
        for image_path in tqdm(image_files, desc="Processing"):
            _process_single(ocr, image_path, args, images_output_dir, viz_output_dir, results)
    results_path = output_dir / "results.json"
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    success_count = sum(1 for r in results.values() if r["status"] == "success")
    total_detections = sum(r["num_detections"] for r in results.values())
    print(f"Done: {success_count}/{len(results)} images, {total_detections} detections -> {output_dir}")
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()