""" Full GPU pipeline example for the SBB binarization ONNX model. Shows how to keep the entire image processing chain on the GPU using CuPy, with only the JPEG decode and TIFF save happening on CPU. This is the approach you'd use for a production pipeline where throughput matters. pip install onnxruntime-gpu cupy-cuda12x numpy Pillow python3 example_gpu_pipeline.py input.jpg output.tif On first run, TensorRT builds an optimized engine (~60-90s). This is cached in ./trt_cache/ and reused on subsequent runs. """ import sys import os import numpy as np import cupy as cp import onnxruntime as ort from PIL import Image MODEL = "model_convtranspose.onnx" PATCH_SIZE = 448 BATCH_SIZE = 64 # ── Normalization LUT ──────────────────────────────────────────────────────── # The original TF model normalizes with: np.array(img) / 255.0 which does # float64 division then truncates to float32. Doing float32 division directly # gives different rounding for some values (off by 1 ULP), which can flip # pixels at the binarization threshold. This LUT preserves the exact behavior. _NORM_LUT = cp.array( np.array([np.float32(np.float64(i) / 255.0) for i in range(256)], dtype=np.float32) ) def create_session(model_path): """Create an ONNX Runtime session with TensorRT backend. TensorRT compiles the model into an optimized GPU engine on first run. The engine is cached to disk so subsequent runs start in ~2 seconds. """ cache_dir = "./trt_cache" os.makedirs(cache_dir, exist_ok=True) return ort.InferenceSession(model_path, providers=[ ("TensorrtExecutionProvider", { "device_id": 0, "trt_fp16_enable": False, # FP32 for accuracy "trt_engine_cache_enable": True, "trt_engine_cache_path": cache_dir, "trt_builder_optimization_level": 3, }), ("CUDAExecutionProvider", {"device_id": 0}), ]) def extract_patches_gpu(img_gpu, patch_size): """Extract non-overlapping patches on GPU. Zero-pads edges.""" h, w = img_gpu.shape[:2] positions = [(x, y) for y in range(0, h, patch_size) for x in range(0, w, patch_size)] patches = cp.zeros((len(positions), patch_size, patch_size, 3), dtype=cp.uint8) for i, (x, y) in enumerate(positions): ph = min(patch_size, h - y) pw = min(patch_size, w - x) patches[i, :ph, :pw, :] = img_gpu[y:y+ph, x:x+pw, :] return patches, positions def infer_patches(session, patches_uint8): """Normalize and run inference, one batch at a time. Normalizing per-batch (64 patches = 154MB) instead of all at once (500+ patches = 2.6GB) avoids GPU memory fragmentation. """ inp = session.get_inputs()[0].name out = session.get_outputs()[0].name n = patches_uint8.shape[0] out_ch = session.get_outputs()[0].shape[3] or 2 all_output = cp.zeros((n, PATCH_SIZE, PATCH_SIZE, out_ch), dtype=cp.float32) for i in range(0, n, BATCH_SIZE): end = min(i + BATCH_SIZE, n) # Normalize on GPU via LUT (uint8 -> float32, 8ms per batch) batch_float = _NORM_LUT[patches_uint8[i:end].astype(cp.int32)] # Transfer to CPU for ONNX Runtime inference result = session.run([out], {inp: batch_float.get()})[0] # Transfer result back to GPU for post-processing all_output[i:end] = cp.asarray(result) return all_output def postprocess_gpu(output): """Extract foreground probability, threshold, binarize — all on GPU.""" probs = output[:, :, :, 1] # channel 1 = foreground quantized = (probs * 255.0).astype(cp.uint8) return cp.where(quantized <= 128, cp.uint8(255), cp.uint8(0)) def reconstruct_gpu(patches, positions, width, height): """Reconstruct full image from patches with overlap averaging — all on GPU.""" result = cp.zeros((height, width), dtype=cp.float32) weight = cp.zeros((height, width), dtype=cp.float32) for i, (x, y) in enumerate(positions): ah = min(PATCH_SIZE, height - y) aw = min(PATCH_SIZE, width - x) result[y:y+ah, x:x+aw] += patches[i, :ah, :aw].astype(cp.float32) weight[y:y+ah, x:x+aw] += 1.0 return (result / cp.maximum(weight, 1.0)).astype(cp.uint8) def binarize_image(input_path, output_path, model_path=MODEL): """Full pipeline: JPEG in -> binarized TIFF out. Data flow: CPU: decode JPEG CPU -> GPU: upload image (~5ms) GPU: extract patches (~7ms) GPU -> CPU -> GPU: normalize, infer, collect (~175ms per batch) GPU: threshold + binarize (~1ms) GPU: reconstruct from patches (~13ms) GPU -> CPU: download result (~2ms) CPU: save Group4 TIFF """ # CPU: decode img = np.array(Image.open(input_path).convert("RGB")) h, w = img.shape[:2] # CPU -> GPU img_gpu = cp.asarray(img) # GPU: patch extraction patches, positions = extract_patches_gpu(img_gpu, PATCH_SIZE) # Inference (normalize on GPU, infer via ORT, results back on GPU) session = create_session(model_path) output = infer_patches(session, patches) # GPU: threshold binary = postprocess_gpu(output) # GPU: reconstruct result_gpu = reconstruct_gpu(binary, positions, w, h) # GPU -> CPU result = result_gpu.get() # CPU: save Image.fromarray(result, "L").convert("1").save( output_path, format="TIFF", compression="group4", dpi=(300, 300) ) print(f"Saved {output_path}") # Clean up GPU memory del img_gpu, patches, output, binary, result_gpu if __name__ == "__main__": if len(sys.argv) < 3: print(f"Usage: {sys.argv[0]} ") sys.exit(1) binarize_image(sys.argv[1], sys.argv[2])