# sbb-binarization-onnx / example_gpu_pipeline.py
# Uploaded by nathansut1 via huggingface_hub (commit d8cc6e6, verified)
"""
Full GPU pipeline example for the SBB binarization ONNX model.
Shows how to keep the entire image processing chain on the GPU using CuPy,
with only the JPEG decode and TIFF save happening on CPU. This is the
approach you'd use for a production pipeline where throughput matters.
pip install onnxruntime-gpu cupy-cuda12x numpy Pillow
python3 example_gpu_pipeline.py input.jpg output.tif
On first run, TensorRT builds an optimized engine (~60-90s). This is
cached in ./trt_cache/ and reused on subsequent runs.
"""
import sys
import os
import numpy as np
import cupy as cp
import onnxruntime as ort
from PIL import Image
MODEL = "model_convtranspose.onnx"
PATCH_SIZE = 448
BATCH_SIZE = 64
# ── Normalization LUT ────────────────────────────────────────────────────────
# The original TF model normalizes with: np.array(img) / 255.0 which does
# float64 division then truncates to float32. Doing float32 division directly
# gives different rounding for some values (off by 1 ULP), which can flip
# pixels at the binarization threshold. This LUT preserves the exact behavior.
_NORM_LUT = cp.array(
np.array([np.float32(np.float64(i) / 255.0) for i in range(256)],
dtype=np.float32)
)
def create_session(model_path):
    """Build an ONNX Runtime session backed by TensorRT, with CUDA fallback.

    On the first run TensorRT compiles the model into an optimized GPU
    engine; the engine is cached under ./trt_cache so later runs start in
    roughly two seconds.
    """
    engine_cache = "./trt_cache"
    os.makedirs(engine_cache, exist_ok=True)
    trt_options = {
        "device_id": 0,
        "trt_fp16_enable": False,  # stay in FP32 for bit-accurate output
        "trt_engine_cache_enable": True,
        "trt_engine_cache_path": engine_cache,
        "trt_builder_optimization_level": 3,
    }
    providers = [
        ("TensorrtExecutionProvider", trt_options),
        ("CUDAExecutionProvider", {"device_id": 0}),
    ]
    return ort.InferenceSession(model_path, providers=providers)
def extract_patches_gpu(img_gpu, patch_size):
    """Cut the image into a grid of non-overlapping patches on the GPU.

    Patches hanging over the right/bottom edges are zero-padded so every
    patch is exactly patch_size x patch_size. Returns the patch stack and
    the list of (x, y) origins, row by row.
    """
    height, width = img_gpu.shape[:2]
    origins = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            origins.append((left, top))
    stack = cp.zeros((len(origins), patch_size, patch_size, 3), dtype=cp.uint8)
    for idx, (left, top) in enumerate(origins):
        rows = min(patch_size, height - top)
        cols = min(patch_size, width - left)
        stack[idx, :rows, :cols, :] = img_gpu[top:top + rows, left:left + cols, :]
    return stack, origins
def infer_patches(session, patches_uint8):
    """Normalize patches on the GPU and run ONNX inference in batches.

    Normalizing per-batch (BATCH_SIZE patches = ~154MB) instead of all at
    once (500+ patches = 2.6GB) avoids GPU memory fragmentation.

    Args:
        session: an onnxruntime InferenceSession for the binarization model.
        patches_uint8: CuPy uint8 array of shape (n, PATCH_SIZE, PATCH_SIZE, 3).

    Returns:
        CuPy float32 array of shape (n, PATCH_SIZE, PATCH_SIZE, out_ch).
    """
    inp = session.get_inputs()[0].name
    out = session.get_outputs()[0].name
    n = patches_uint8.shape[0]
    # ONNX dynamic dims are symbolic *strings* (e.g. "unk__42"), which are
    # truthy, so `shape[3] or 2` would pass the string straight through and
    # cp.zeros below would raise. Only accept a concrete positive int;
    # otherwise fall back to the model's 2 output channels.
    ch_dim = session.get_outputs()[0].shape[3]
    out_ch = ch_dim if isinstance(ch_dim, int) and ch_dim > 0 else 2
    all_output = cp.zeros((n, PATCH_SIZE, PATCH_SIZE, out_ch), dtype=cp.float32)
    for i in range(0, n, BATCH_SIZE):
        end = min(i + BATCH_SIZE, n)
        # Normalize on GPU via LUT (uint8 -> float32, 8ms per batch)
        batch_float = _NORM_LUT[patches_uint8[i:end].astype(cp.int32)]
        # ORT needs host memory: copy the batch down, run, copy results back up.
        result = session.run([out], {inp: batch_float.get()})[0]
        all_output[i:end] = cp.asarray(result)
    return all_output
def postprocess_gpu(output):
    """Turn the network output into binary patches, entirely on the GPU.

    Channel 1 of the last axis holds the foreground probability. It is
    quantized to uint8 like the reference implementation, then thresholded:
    weak foreground (quantized value <= 128) becomes white (255), strong
    foreground becomes black (0).
    """
    fg_prob = output[:, :, :, 1]
    quantized = (fg_prob * 255.0).astype(cp.uint8)
    white = cp.uint8(255)
    black = cp.uint8(0)
    return cp.where(quantized <= 128, white, black)
def reconstruct_gpu(patches, positions, width, height):
    """Stitch patches back into one (height, width) image on the GPU.

    Each patch is accumulated into a float canvas alongside a coverage
    count; dividing canvas by coverage averages any overlapping regions.
    Coverage is clamped to 1 so uncovered pixels divide by 1 (yielding 0)
    instead of dividing by zero.
    """
    canvas = cp.zeros((height, width), dtype=cp.float32)
    coverage = cp.zeros((height, width), dtype=cp.float32)
    for patch, (left, top) in zip(patches, positions):
        rows = min(PATCH_SIZE, height - top)
        cols = min(PATCH_SIZE, width - left)
        canvas[top:top + rows, left:left + cols] += patch[:rows, :cols].astype(cp.float32)
        coverage[top:top + rows, left:left + cols] += 1.0
    averaged = canvas / cp.maximum(coverage, 1.0)
    return averaged.astype(cp.uint8)
def binarize_image(input_path, output_path, model_path=MODEL):
    """Full pipeline: JPEG in -> binarized TIFF out.

    Data flow:
      CPU: decode JPEG
      CPU -> GPU: upload image (~5ms)
      GPU: extract patches (~7ms)
      GPU -> CPU -> GPU: normalize, infer, collect (~175ms per batch)
      GPU: threshold + binarize (~1ms)
      GPU: reconstruct from patches (~13ms)
      GPU -> CPU: download result (~2ms)
      CPU: save Group4 TIFF
    """
    # Decode on the CPU, then push the RGB pixels to the GPU.
    rgb = np.array(Image.open(input_path).convert("RGB"))
    height, width = rgb.shape[:2]
    rgb_gpu = cp.asarray(rgb)
    # Tile into model-sized patches (GPU).
    patches, positions = extract_patches_gpu(rgb_gpu, PATCH_SIZE)
    # Inference: normalization and result collection stay on the GPU.
    session = create_session(model_path)
    raw_output = infer_patches(session, patches)
    # Threshold to binary patches, then stitch them back together (GPU).
    binary_patches = postprocess_gpu(raw_output)
    stitched = reconstruct_gpu(binary_patches, positions, width, height)
    # Pull the final image down and write a Group4-compressed TIFF.
    final = stitched.get()
    Image.fromarray(final, "L").convert("1").save(
        output_path, format="TIFF", compression="group4", dpi=(300, 300)
    )
    print(f"Saved {output_path}")
    # Drop GPU references so the memory pool can reclaim them.
    del rgb_gpu, patches, raw_output, binary_patches, stitched
if __name__ == "__main__":
    # Require both an input image path and an output TIFF path.
    if len(sys.argv) >= 3:
        binarize_image(sys.argv[1], sys.argv[2])
    else:
        print(f"Usage: {sys.argv[0]} <input.jpg> <output.tif>")
        sys.exit(1)