# nemotron-ocr-v2 / app.py — last commit dd40da5 by emelryan ("batched pipeline")
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Nemotron OCR v2 — HuggingFace Space (ZeroGPU)."""
import logging
import os
import subprocess
import sys
logging.basicConfig(level=logging.INFO, format="%(name)s %(levelname)s: %(message)s")

# nemotron_ocr is installed from a local wheel (expected to sit alongside this
# script) when it is not already importable. --no-deps: only the wheel itself
# is installed; its dependencies must already be present.
try:
    import nemotron_ocr  # noqa: F401  (availability probe; real imports follow below)
except ImportError:
    import importlib

    print("Installing nemotron_ocr wheel...", flush=True)
    # Resolve against this file's directory rather than the CWD so the app can
    # be launched from elsewhere and still find the wheel.
    _wheel_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "nemotron_ocr-1.0.0-cp312-cp312-linux_x86_64.whl",
    )
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-deps", _wheel_path])
    # Path finders cache directory listings; without invalidating them the
    # freshly installed package may not be visible to the imports below.
    importlib.invalidate_caches()
import json
from pathlib import Path
import spaces
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw
from nemotron_ocr.inference.pipeline_v2 import NemotronOCRV2
# Maps the label shown in the model dropdown to the pipeline key that
# _get_pipeline() understands. Insertion order is the dropdown's choice order.
MODELS = dict(
    [
        ("Multilingual (en, zh, ja, ko, ru, …)", "multi"),
        ("English-only", "en"),
        ("v1 (legacy, English-only)", "v1"),
    ]
)
# Pipeline cache filled lazily by _get_pipeline(), keyed by MODELS value, so
# each model is constructed at most once per process.
_pipelines: dict[str, NemotronOCRV2] = dict()
# Palette that draw_boxes() cycles through, as (R, G, B) tuples. The values are
# the Material Design "500" swatches, decoded here from their hex codes.
GROUP_COLORS = [
    tuple(int(hex_code[pos:pos + 2], 16) for pos in (0, 2, 4))
    for hex_code in (
        "4CAF50",  # green
        "2196F3",  # blue
        "FF9800",  # orange
        "9C27B0",  # purple
        "00BCD4",  # cyan
        "F44336",  # red
        "8BC34A",  # light green
        "795548",  # brown
        "FFEB3B",  # yellow
        "3F51B5",  # indigo
        "E91E63",  # pink
        "009688",  # teal
        "FF5722",  # deep orange
        "673AB7",  # deep purple
        "CDDC39",  # lime
        "607D8B",  # blue grey
    )
]
# Contents dumped to model_config.json by _ensure_v1_model_dir() when the
# downloaded v1 checkpoint directory does not already contain that file.
# Keyword order is preserved, so the JSON key order is unchanged.
V1_DEFAULT_CONFIG = dict(
    num_tokens=858,
    max_width=32,
    sequence_length=32,
    scope=512,
    coordinate_mode="RBOX",
    backbone="regnet_y_8gf",
    charset_size=855,
)
def _ensure_v1_model_dir() -> str:
    """Return a local directory containing the v1 checkpoints and a model config.

    Downloads ``checkpoints/{detector,recognizer,relational}.pth`` and
    ``checkpoints/charset.txt`` from ``nvidia/nemotron-ocr-v1`` via
    ``hf_hub_download`` (in that order) and returns the parent directory of the
    last one. If that directory has no ``model_config.json``, the contents of
    ``V1_DEFAULT_CONFIG`` are written there as JSON.
    """
    downloaded = [
        Path(hf_hub_download(repo_id="nvidia/nemotron-ocr-v1", filename=f"checkpoints/{name}"))
        for name in ("detector.pth", "recognizer.pth", "relational.pth", "charset.txt")
    ]
    model_dir = str(downloaded[-1].parent)

    config_file = Path(model_dir) / "model_config.json"
    if not config_file.exists():
        with open(config_file, "w") as handle:
            json.dump(V1_DEFAULT_CONFIG, handle)
    return model_dir
def _get_pipeline(lang_key: str) -> NemotronOCRV2:
    """Return the cached pipeline for *lang_key*, constructing it on first use.

    ``"v1"`` (or its alias ``"legacy"``) builds the pipeline from the local
    directory prepared by ``_ensure_v1_model_dir``; any other key is passed to
    ``NemotronOCRV2`` as ``lang=``. Instances are memoized in ``_pipelines``.
    """
    if lang_key in _pipelines:
        return _pipelines[lang_key]

    if lang_key in ("v1", "legacy"):
        pipeline = NemotronOCRV2(model_dir=_ensure_v1_model_dir())
    else:
        pipeline = NemotronOCRV2(lang=lang_key)
    _pipelines[lang_key] = pipeline
    return pipeline
def draw_boxes(image: Image.Image, predictions: list[dict]) -> Image.Image:
    """Return an RGB copy of *image* with every prediction's box drawn on it.

    Each box gets a translucent fill and a solid 2px outline. The color is
    picked from ``GROUP_COLORS`` by the prediction's position in the list
    (cycling through the palette), not by any grouping field.

    Args:
        image: Source image; it is not modified.
        predictions: Dicts with ``left``/``upper``/``right``/``lower`` keys.
            The values are multiplied by the image width/height, so they are
            presumably normalized to [0, 1]. Corners may come in either order.

    Returns:
        A new RGB image with the boxes alpha-composited over *image*.
    """
    # convert() already returns a new image, so no explicit copy() is needed.
    base = image.convert("RGBA")
    box_layer = Image.new("RGBA", base.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(box_layer)
    w, h = image.size
    for i, pred in enumerate(predictions):
        xs = (pred["left"] * w, pred["right"] * w)
        ys = (pred["upper"] * h, pred["lower"] * h)
        left, right = min(xs), max(xs)
        upper, lower = min(ys), max(ys)
        # Skip degenerate boxes less than one pixel wide or tall.
        if right - left < 1 or lower - upper < 1:
            continue
        r, g, b = GROUP_COLORS[i % len(GROUP_COLORS)]
        draw.rectangle(
            [left, upper, right, lower],
            fill=(r, g, b, 40),
            outline=(r, g, b, 220),
            width=2,
        )
    return Image.alpha_composite(base, box_layer).convert("RGB")
# ---------------------------------------------------------------------------
# Layout-aware spatial text reconstruction
# Adapted from https://github.com/h2oai/h2ogpt/blob/main/src/image_doctr.py
# ---------------------------------------------------------------------------
def _is_same_line(box_a: dict, box_b: dict) -> bool:
"""Two word boxes are on the same line if their vertical midpoints overlap."""
a_mid = (box_a["upper"] + box_a["lower"]) / 2
b_mid = (box_b["upper"] + box_b["lower"]) / 2
a_top, a_bot = min(box_a["upper"], box_a["lower"]), max(box_a["upper"], box_a["lower"])
b_top, b_bot = min(box_b["upper"], box_b["lower"]), max(box_b["upper"], box_b["lower"])
return a_top < b_mid < a_bot and b_top < a_mid < b_bot
def _group_lines(boxes: np.ndarray) -> list[list[int]]:
    """Partition row indices of an ``(n, 4)`` ``[x0, y0, x1, y1]`` array into lines.

    Lines are returned top to bottom; indices within a line are sorted left to right.
    """
    # Visit words top-down (sorted() is stable, so ties keep input order). The
    # first unassigned word reached is therefore always the topmost remaining
    # one, and it anchors the next line.
    order = sorted(range(len(boxes)), key=lambda i: boxes[i][1])
    mids = (boxes[:, 1] + boxes[:, 3]) / 2
    assigned: set[int] = set()
    lines: list[list[int]] = []
    for anchor in order:
        if anchor in assigned:
            continue
        top, bottom = boxes[anchor][1], boxes[anchor][3]
        # A word joins the anchor's line when its vertical midpoint falls
        # within the anchor's vertical extent (inclusive).
        members = [i for i in order if i not in assigned and top <= mids[i] <= bottom]
        if not members:
            # Only possible for non-finite coordinates; drop rather than loop.
            continue
        assigned.update(members)
        members.sort(key=lambda i: boxes[i][0])
        lines.append(members)
    return lines


def space_layout(words: list[dict]) -> str:
    """Reconstruct the page layout using monospace-style spacing.

    Words are grouped into lines top to bottom: the topmost unassigned word
    anchors a line and every unassigned word whose vertical midpoint lies within
    the anchor's vertical extent joins it. Each line is ordered left to right
    and padded with spaces so horizontal positions are roughly preserved. The
    width of one character column is estimated from the line with the most
    characters (its horizontal span divided by its character count).

    Args:
        words: Word-level predictions with ``text`` and
            ``left``/``upper``/``right``/``lower`` keys; corners may be given in
            either order. Words with missing or empty text are ignored, matching
            ``format_text``.

    Returns:
        One output line per detected text line, joined by newlines, or
        ``"(no text detected)"`` when no word carries text.
    """
    words = [w for w in words if w.get("text")]
    if not words:
        return "(no text detected)"

    texts = [w["text"] for w in words]
    boxes = np.array([
        [min(w["left"], w["right"]),
         min(w["upper"], w["lower"]),
         max(w["left"], w["right"]),
         max(w["upper"], w["lower"])]
        for w in words
    ])
    lines = _group_lines(boxes)

    # Character width = span / length of the line with the most characters
    # (the first such line wins ties).
    widest_chars = 0
    widest_span = 0.0
    for members in lines:
        n_chars = len(" ".join(texts[i] for i in members))
        if n_chars > widest_chars:
            widest_chars = n_chars
            widest_span = float(boxes[members, 2].max() - boxes[members, 0].min())
    char_width = widest_span / widest_chars if widest_chars > 0 else 0.02
    # Floor the width so a near-zero span cannot produce huge column indices.
    char_width = max(char_width, 0.005)

    output_lines = []
    for members in lines:
        line = ""
        for j, i in enumerate(members):
            target_col = int(boxes[i][0] / char_width)
            # Pad out to the word's column, but keep at least one space between words.
            pad = max(target_col - len(line), 1 if j > 0 else 0)
            line += " " * pad + texts[i]
        output_lines.append(line)
    return "\n".join(output_lines)
def format_text(predictions: list[dict], merge_level: str) -> str:
texts = [p.get("text", "") for p in predictions if p.get("text")]
if not texts:
return "(no text detected)"
if merge_level == "paragraph":
return "\n\n".join(texts)
elif merge_level == "sentence":
return "\n".join(texts)
else:
return " ".join(texts)
@spaces.GPU(duration=120)
def run_ocr(image: Image.Image, model_name: str, merge_level: str):
    """Gradio click handler: OCR *image* with the chosen model and output mode.

    Returns ``(annotated_image, text)``. In ``"layout"`` mode the pipeline runs
    at word level and the words are re-flowed by ``space_layout``; otherwise it
    runs at *merge_level* and the texts are joined by ``format_text``. Without
    an image, returns ``(None, "Please upload an image.")``.
    """
    if image is None:
        return None, "Please upload an image."

    ocr = _get_pipeline(MODELS[model_name])
    rgb_pixels = np.array(image.convert("RGB"))

    layout_mode = merge_level == "layout"
    predictions = ocr(rgb_pixels, merge_level="word" if layout_mode else merge_level)
    annotated = draw_boxes(image, predictions)
    if layout_mode:
        text = space_layout(predictions)
    else:
        text = format_text(predictions, merge_level)
    return annotated, text
with gr.Blocks(
    title="Nemotron OCR v2",
    theme=gr.themes.Default(),
) as demo:
    gr.Markdown(
        "# Nemotron OCR v2\n"
        "State-of-the-art multilingual OCR by NVIDIA. "
        "Upload an image to extract text with bounding boxes.\n\n"
        "*Powered by ZeroGPU — GPU allocated on demand.*"
    )
    # Disclaimer banner. The body text uses the theme's text color so it stays
    # readable in both light and dark mode; the previous fixed light gray
    # (#d1d5db) was nearly invisible on the pale amber background in light mode.
    gr.HTML(
        """
        <div style="
          margin: -2px 0 14px 0;
          padding: 8px 10px;
          background: rgba(245, 158, 11, 0.10);
          border: 1px solid rgba(245, 158, 11, 0.28);
          border-radius: 8px;
          font-size: 11.5px;
          line-height: 1.45;
          color: var(--body-text-color);
        ">
          <strong style="color: #fbbf24;">Disclaimer:</strong>
          This demo is for evaluation and testing data only. Do not upload confidential, sensitive, or personal data.
          Use of this demo is at your own risk. This demo runs on Hugging Face Spaces and is subject to
          <a href="https://huggingface.co/terms-of-service" target="_blank" rel="noopener noreferrer">Hugging Face's Terms of Service</a>
          and
          <a href="https://huggingface.co/privacy" target="_blank" rel="noopener noreferrer">Privacy Policy</a>.
          NVIDIA does not store or retain copies of data you submit or outputs generated as part of this demo
          outside this Hugging Face Space.
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image")
            model_choice = gr.Dropdown(
                choices=list(MODELS.keys()),
                # Default to the first MODELS entry (the multilingual model);
                # deriving it avoids a duplicated label literal that could drift.
                value=next(iter(MODELS)),
                label="Model",
            )
            merge_level = gr.Radio(
                choices=["layout", "word", "sentence", "paragraph"],
                value="layout",
                label="Output Mode",
            )
            run_btn = gr.Button("Run OCR", variant="primary")
        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Detected Regions")
            output_text = gr.Textbox(
                label="Extracted Text",
                lines=20,
                show_copy_button=True,
            )
    run_btn.click(
        fn=run_ocr,
        inputs=[input_image, model_choice, merge_level],
        outputs=[output_image, output_text],
    )
# Bind address and port are read from the environment when set; otherwise the
# app listens on all interfaces on port 7860.
_server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
_server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
demo.launch(server_name=_server_name, server_port=_server_port)