Ryan Chesler
fix missing runtime/build deps and add Slurm helper scripts

pyproject.toml: add numpy and huggingface_hub to dependencies, and editables to build-system.requires, so editable installs work cleanly without manual workarounds.

New scripts:
- submit_install_gpu.sh: GPU-node venv creation + package build
- submit_inference_viz_gpu.sh: GPU-node batched v2 inference + viz
- run_ocr_inference.py: directory-level inference with JSON output
c239c5e | #!/usr/bin/env python3 | |
| """ | |
| Run OCR inference over a directory of images and write: | |
| - images/ : JPG copies of processed input images | |
| - visualizations/ : Annotated images with bounding boxes + text | |
| - results.json : OCR outputs for all images | |
| """ | |
| import argparse | |
| import json | |
| import logging | |
| import os | |
| import random | |
| from pathlib import Path | |
| from PIL import Image, ImageDraw, ImageFont | |
| from tqdm import tqdm | |
| from nemotron_ocr.inference.pipeline import NemotronOCR | |
| from nemotron_ocr.inference.pipeline_v2 import NemotronOCRV2 | |
# Resolve bundled assets relative to this script so it works from any CWD.
SCRIPT_DIR = Path(__file__).parent.resolve()
FONT_PATH = SCRIPT_DIR / "NotoSansCJKsc-Regular.otf"

# Quiet the chatty OCR pipeline loggers by default (--verbose re-enables INFO).
for _noisy_logger in (
    "nemotron_ocr.inference.pipeline_v2",
    "nemotron_ocr.inference.models.relational",
):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
def get_font(size: int = 14):
    """Return the bundled CJK TrueType font at *size*.

    Falls back to PIL's built-in default bitmap font when the .otf asset is
    missing or unreadable, so label drawing never fails on a bad install.
    """
    try:
        font = ImageFont.truetype(str(FONT_PATH), size)
    except Exception:
        font = ImageFont.load_default()
    return font
def find_images(images_dir, extensions=(".png", ".jpg", ".jpeg", ".tiff", ".bmp")):
    """Recursively collect image files under *images_dir*.

    Extension matching is fully case-insensitive: the original implementation
    globbed only the all-lower and all-upper spellings of each extension, so
    mixed-case suffixes such as ``.Jpg`` were silently skipped. A single
    ``rglob("*")`` pass with a lowercased-suffix check covers every casing.

    Args:
        images_dir: Directory to search (str or Path).
        extensions: Accepted file suffixes, compared case-insensitively.

    Returns:
        Sorted, de-duplicated list of matching ``Path`` objects.
    """
    allowed = {ext.lower() for ext in extensions}
    root = Path(images_dir)
    # One traversal; p.suffix.lower() normalizes .JPG / .Jpg / .jpg alike.
    matches = {p for p in root.rglob("*") if p.is_file() and p.suffix.lower() in allowed}
    return sorted(matches)
def save_visualization(image_path: str, predictions: list, output_path: str):
    """Render OCR predictions onto a copy of the image and save it as JPEG.

    Each prediction gets a bounding box in a rotating 8-color palette and a
    text label (truncated to 50 characters) with a white backing box for
    readability. Box coordinates in *predictions* are fractional (0-1) and
    are scaled to pixel space here.

    NOTE(review): assumes each prediction dict carries fractional
    "left"/"right" plus "upper"/"lower" (or "bottom"/"top") keys and a
    "text" key -- confirm against the OCR pipeline's output schema.
    """
    palette = (
        (255, 0, 0),
        (0, 255, 0),
        (0, 0, 255),
        (255, 165, 0),
        (128, 0, 128),
        (0, 128, 128),
        (255, 192, 203),
        (165, 42, 42),
    )
    canvas = Image.open(image_path).convert("RGB")
    painter = ImageDraw.Draw(canvas)
    label_font = get_font(14)
    width, height = canvas.size
    for idx, pred in enumerate(predictions):
        # Skip rows whose geometry was serialized as the literal string "nan".
        if isinstance(pred.get("left"), str) and pred["left"] == "nan":
            continue
        color = palette[idx % len(palette)]
        x0 = int(pred["left"] * width)
        x1 = int(pred["right"] * width)
        y_a = int(pred.get("upper", pred.get("bottom", 0)) * height)
        y_b = int(pred.get("lower", pred.get("top", 0)) * height)
        # Normalize so y0 is always the visually upper edge.
        y0, y1 = (y_a, y_b) if y_a <= y_b else (y_b, y_a)
        painter.rectangle([x0, y0, x1, y1], outline=color, width=2)
        full_text = pred["text"]
        label = full_text[:50] + "..." if len(full_text) > 50 else full_text
        # Place the label just above the box, clamped to the image top.
        label_y = max(0, y0 - 20)
        try:
            bbox = painter.textbbox((x0, label_y), label, font=label_font)
            painter.rectangle(
                [bbox[0] - 2, bbox[1] - 2, bbox[2] + 2, bbox[3] + 2],
                fill=(255, 255, 255, 200),
                outline=color,
            )
            painter.text((x0, label_y), label, fill=color, font=label_font)
        except Exception:
            # Backing-box path failed (e.g. bitmap font without textbbox
            # support) -- fall back to drawing the bare label.
            painter.text((x0, label_y), label, fill=color, font=label_font)
    parent_dir = os.path.dirname(output_path)
    os.makedirs(parent_dir if parent_dir else ".", exist_ok=True)
    canvas.save(output_path, "JPEG", quality=90)
def save_image_as_jpg(image_path: str, output_path: str):
    """Re-encode the image at *image_path* as an RGB JPEG (quality 90)."""
    Image.open(image_path).convert("RGB").save(output_path, "JPEG", quality=90)
def _store_result(image_path: Path, predictions: list, images_output_dir: Path, viz_output_dir: Path):
    """Persist one processed image's artifacts and build its results row.

    Copies the source image into *images_output_dir* as ``<stem>.jpg``
    (skipped when already present, so reruns don't re-encode), writes a
    ``<stem>_viz.jpg`` visualization when at least one prediction exists,
    and returns ``(key, row)`` for the results.json mapping.
    """
    jpg_name = f"{image_path.stem}.jpg"
    row = {
        "original_path": str(image_path),
        "predictions": predictions,
        "num_detections": len(predictions),
        "status": "success",
    }
    copy_target = images_output_dir / jpg_name
    if not copy_target.exists():
        save_image_as_jpg(str(image_path), str(copy_target))
    if predictions:
        viz_target = viz_output_dir / f"{image_path.stem}_viz.jpg"
        save_visualization(str(image_path), predictions, str(viz_target))
    return jpg_name, row
| def main(): | |
| parser = argparse.ArgumentParser(description="Run OCR inference on image directories") | |
| parser.add_argument("--model_dir", type=str, required=True, help="Path to model checkpoint directory") | |
| parser.add_argument("--images_dir", type=str, required=True, help="Path to directory with input images") | |
| parser.add_argument("--output_dir", type=str, required=True, help="Output directory") | |
| parser.add_argument("--num_samples", type=int, default=0, help="Number of images to process (0 for all)") | |
| parser.add_argument( | |
| "--merge_level", | |
| type=str, | |
| default="sentence", | |
| choices=["word", "sentence", "paragraph"], | |
| help="Text merging granularity", | |
| ) | |
| parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling") | |
| parser.add_argument("--include_invalid", action="store_true", help="Include regions marked invalid") | |
| parser.add_argument("--infer_length", type=int, default=1024, help="Detector resolution") | |
| parser.add_argument("--pipeline", type=str, default="v2", choices=("v1", "v2"), help="OCR pipeline version") | |
| parser.add_argument("--batch_size", type=int, default=8, help="Batch size for v2 pipeline") | |
| parser.add_argument("--verbose", action="store_true", help="Enable verbose OCR logs") | |
| args = parser.parse_args() | |
| if args.verbose: | |
| logging.getLogger("nemotron_ocr.inference.pipeline_v2").setLevel(logging.INFO) | |
| output_dir = Path(args.output_dir) | |
| images_output_dir = output_dir / "images" | |
| viz_output_dir = output_dir / "visualizations" | |
| output_dir.mkdir(parents=True, exist_ok=True) | |
| images_output_dir.mkdir(exist_ok=True) | |
| viz_output_dir.mkdir(exist_ok=True) | |
| image_files = [Path(p) for p in find_images(args.images_dir)] | |
| if not image_files: | |
| print("No images found.") | |
| return | |
| if 0 < args.num_samples < len(image_files): | |
| random.seed(args.seed) | |
| image_files = random.sample(image_files, args.num_samples) | |
| if args.pipeline == "v2": | |
| ocr = NemotronOCRV2( | |
| model_dir=args.model_dir, | |
| infer_length=args.infer_length, | |
| detector_max_batch_size=max(1, args.batch_size), | |
| ) | |
| else: | |
| ocr = NemotronOCR(args.model_dir, infer_length=args.infer_length) | |
| results = {} | |
| if args.pipeline == "v2": | |
| bs = max(1, args.batch_size) | |
| for start in tqdm(range(0, len(image_files), bs), desc="Batches"): | |
| batch_paths = image_files[start:start + bs] | |
| batch_strs = [str(p) for p in batch_paths] | |
| batch_preds = None | |
| try: | |
| batch_preds = ocr( | |
| batch_strs, | |
| merge_level=args.merge_level, | |
| include_invalid=args.include_invalid, | |
| ) | |
| except Exception: | |
| batch_preds = None | |
| if batch_preds is None: | |
| for image_path in batch_paths: | |
| try: | |
| predictions = ocr( | |
| str(image_path), | |
| merge_level=args.merge_level, | |
| include_invalid=args.include_invalid, | |
| ) | |
| key, row = _store_result(image_path, predictions, images_output_dir, viz_output_dir) | |
| results[key] = row | |
| except Exception as e: | |
| import traceback | |
| results[f"{image_path.stem}.jpg"] = { | |
| "original_path": str(image_path), | |
| "predictions": [], | |
| "num_detections": 0, | |
| "status": "error", | |
| "error": str(e), | |
| "traceback": traceback.format_exc(), | |
| } | |
| continue | |
| for image_path, predictions in zip(batch_paths, batch_preds): | |
| try: | |
| key, row = _store_result(image_path, predictions, images_output_dir, viz_output_dir) | |
| results[key] = row | |
| except Exception as e: | |
| import traceback | |
| results[f"{image_path.stem}.jpg"] = { | |
| "original_path": str(image_path), | |
| "predictions": [], | |
| "num_detections": 0, | |
| "status": "error", | |
| "error": str(e), | |
| "traceback": traceback.format_exc(), | |
| } | |
| else: | |
| for image_path in tqdm(image_files, desc="Processing"): | |
| try: | |
| predictions = ocr( | |
| str(image_path), | |
| merge_level=args.merge_level, | |
| include_invalid=args.include_invalid, | |
| ) | |
| key, row = _store_result(image_path, predictions, images_output_dir, viz_output_dir) | |
| results[key] = row | |
| except Exception as e: | |
| import traceback | |
| results[f"{image_path.stem}.jpg"] = { | |
| "original_path": str(image_path), | |
| "predictions": [], | |
| "num_detections": 0, | |
| "status": "error", | |
| "error": str(e), | |
| "traceback": traceback.format_exc(), | |
| } | |
| results_path = output_dir / "results.json" | |
| with open(results_path, "w", encoding="utf-8") as f: | |
| json.dump(results, f, ensure_ascii=False, indent=2) | |
| success_count = sum(1 for r in results.values() if r["status"] == "success") | |
| total_detections = sum(r["num_detections"] for r in results.values()) | |
| print(f"Done: {success_count}/{len(results)} images, {total_detections} detections -> {output_dir}") | |
| if __name__ == "__main__": | |
| main() | |