File size: 5,912 Bytes
d8cc6e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
"""
Full GPU pipeline example for the SBB binarization ONNX model.

Shows how to keep the image processing chain on the GPU using CuPy. Besides
the JPEG decode and TIFF save, the only CPU work is handing each normalized
batch to ONNX Runtime and copying its output back to the GPU. This is the
approach you'd use for a production pipeline where throughput matters.

    pip install onnxruntime-gpu cupy-cuda12x numpy Pillow
    python3 example_gpu_pipeline.py input.jpg output.tif

On first run, TensorRT builds an optimized engine (~60-90s). This is
cached in ./trt_cache/ and reused on subsequent runs.
"""

import sys
import os
import numpy as np
import cupy as cp
import onnxruntime as ort
from PIL import Image

MODEL = "model_convtranspose.onnx"
PATCH_SIZE = 448
BATCH_SIZE = 64

# ── Normalization LUT ────────────────────────────────────────────────────────
# The original TF model normalizes with np.array(img) / 255.0: the division is
# done in float64 and the result is then rounded to float32. Dividing in
# float32 directly can land 1 ULP away for some inputs, which is enough to flip
# pixels sitting right on the binarization threshold. Precomputing all 256
# values the float64 way and indexing into this table reproduces the original
# behavior exactly.
_NORM_LUT = cp.array(
    (np.arange(256, dtype=np.float64) / 255.0).astype(np.float32)
)


def create_session(model_path, cache_dir="./trt_cache", device_id=0):
    """Create an ONNX Runtime session with TensorRT backend.

    TensorRT compiles the model into an optimized GPU engine on first run.
    The engine is cached to disk so subsequent runs start in ~2 seconds.

    Args:
        model_path: Path to the ONNX model file.
        cache_dir: Directory for the TensorRT engine cache; created if it does
            not exist. Defaults to ``./trt_cache`` (relative to the current
            working directory), matching the previous hard-coded location.
        device_id: CUDA device ordinal passed to both execution providers.

    Returns:
        An ``onnxruntime.InferenceSession`` whose providers are listed in
        priority order: TensorRT first, then CUDA.
    """
    os.makedirs(cache_dir, exist_ok=True)
    trt_options = {
        "device_id": device_id,
        "trt_fp16_enable": False,          # FP32 for accuracy
        "trt_engine_cache_enable": True,
        "trt_engine_cache_path": cache_dir,
        "trt_builder_optimization_level": 3,
    }
    cuda_options = {"device_id": device_id}
    return ort.InferenceSession(model_path, providers=[
        ("TensorrtExecutionProvider", trt_options),
        ("CUDAExecutionProvider", cuda_options),
    ])


def extract_patches_gpu(img_gpu, patch_size):
    """Split an image into non-overlapping square patches on GPU; zero-pad edges.

    Args:
        img_gpu: Image array of shape (H, W) or (H, W, C). Patches keep the
            input's dtype and any trailing axes, so RGB uint8, grayscale and
            float images are handled without silent casting or shape errors.
        patch_size: Side length of each square patch in pixels.

    Returns:
        ``(patches, positions)``. ``patches`` has shape
        ``(n, patch_size, patch_size, *img_gpu.shape[2:])``. ``positions`` is a
        list of ``(x, y)`` top-left corners, ordered left to right within a
        row of patches and then top to bottom. Parts of edge patches that fall
        outside the image are left as zeros.
    """
    h, w = img_gpu.shape[:2]
    positions = [(x, y) for y in range(0, h, patch_size)
                        for x in range(0, w, patch_size)]

    # Size and dtype follow the input instead of assuming (..., 3) uint8.
    patches = cp.zeros((len(positions), patch_size, patch_size) + img_gpu.shape[2:],
                       dtype=img_gpu.dtype)
    for i, (x, y) in enumerate(positions):
        ph = min(patch_size, h - y)
        pw = min(patch_size, w - x)
        # Trailing axes (e.g. colour channels) are copied implicitly.
        patches[i, :ph, :pw] = img_gpu[y:y+ph, x:x+pw]

    return patches, positions


def infer_patches(session, patches_uint8, batch_size=None):
    """Normalize and run inference, one batch at a time.

    Normalizing per-batch (64 patches = 154MB) instead of all at once
    (500+ patches = 2.6GB) avoids GPU memory fragmentation.

    Args:
        session: ONNX Runtime session; its first input and first output are used.
        patches_uint8: uint8 CuPy array of patches, batched along axis 0.
        batch_size: Patches per ``session.run`` call. ``None`` (the default)
            uses the module-level ``BATCH_SIZE``.

    Returns:
        float32 CuPy array with the model output for every patch, stacked along
        axis 0. The per-patch shape comes from the actual model output rather
        than the declared output shape. The declared shape may contain
        symbolic dimension names (strings) or ``None`` instead of integers.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    inp = session.get_inputs()[0].name
    out_meta = session.get_outputs()[0]
    out = out_meta.name
    n = patches_uint8.shape[0]

    if n == 0:
        # Nothing to run: return an empty array in the declared layout, falling
        # back to 2 channels when the declared dim is not a concrete integer.
        declared_ch = out_meta.shape[3]
        out_ch = declared_ch if isinstance(declared_ch, int) else 2
        return cp.zeros((0, PATCH_SIZE, PATCH_SIZE, out_ch), dtype=cp.float32)

    all_output = None
    for i in range(0, n, batch_size):
        end = min(i + batch_size, n)

        # Normalize on GPU via LUT (uint8 -> float32, 8ms per batch)
        batch_float = _NORM_LUT[patches_uint8[i:end].astype(cp.int32)]

        # Transfer to CPU for ONNX Runtime inference
        result = session.run([out], {inp: batch_float.get()})[0]

        if all_output is None:
            # Size the buffer from the first real result. Every row is written
            # by the loop, so it need not be zero-initialised.
            all_output = cp.empty((n,) + result.shape[1:], dtype=cp.float32)

        # Transfer result back to GPU for post-processing
        all_output[i:end] = cp.asarray(result)

    return all_output


def postprocess_gpu(output):
    """Turn 4-D model output into a uint8 black/white mask — all on GPU.

    Channel 1 of the last axis is read as the foreground probability. It is
    scaled to 0..255 and truncated to uint8. Levels above 128 map to 0
    (black); everything else maps to 255 (white). The channel axis is dropped.
    """
    foreground = output[:, :, :, 1]  # channel 1 = foreground
    levels = (foreground * 255.0).astype(cp.uint8)
    black, white = cp.uint8(0), cp.uint8(255)
    return cp.where(levels > 128, black, white)


def reconstruct_gpu(patches, positions, width, height):
    """Reassemble a full image from patches, averaging any overlaps — all on GPU.

    Args:
        patches: Array of shape (n, patch_h, patch_w), one 2-D patch per entry
            of ``positions``. The patch size is read from this array, so it
            always matches the ``patch_size`` the patches were cut with.
        positions: Sequence of ``(x, y)`` top-left corners, one per patch.
        width: Output image width in pixels.
        height: Output image height in pixels.

    Returns:
        uint8 array of shape (height, width). Pixels covered by several patches
        get the truncated mean of their values. Pixels covered by no patch
        are 0.
    """
    patch_h, patch_w = patches.shape[1:3]
    result = cp.zeros((height, width), dtype=cp.float32)
    weight = cp.zeros((height, width), dtype=cp.float32)

    for i, (x, y) in enumerate(positions):
        # Clip patches that extend past the right/bottom image border.
        ah = min(patch_h, height - y)
        aw = min(patch_w, width - x)
        result[y:y+ah, x:x+aw] += patches[i, :ah, :aw].astype(cp.float32)
        weight[y:y+ah, x:x+aw] += 1.0

    # Clamping the weight to >= 1 avoids 0/0 for uncovered pixels (they stay 0).
    return (result / cp.maximum(weight, 1.0)).astype(cp.uint8)


def binarize_image(input_path, output_path, model_path=MODEL, session=None):
    """Full pipeline: JPEG in -> binarized TIFF out.

    Data flow:
        CPU: decode JPEG
        CPU -> GPU: upload image (~5ms)
        GPU: extract patches (~7ms)
        GPU -> CPU -> GPU: normalize, infer, collect (~175ms per batch)
        GPU: threshold + binarize (~1ms)
        GPU: reconstruct from patches (~13ms)
        GPU -> CPU: download result (~2ms)
        CPU: save Group4 TIFF

    Args:
        input_path: Image file readable by Pillow; it is converted to RGB.
        output_path: Destination for the 1-bit, Group 4-compressed TIFF,
            tagged at 300 dpi.
        model_path: ONNX model used to build a session when ``session`` is not
            supplied.
        session: Optional pre-built session (e.g. from ``create_session``).
            Pass one when processing many images so the TensorRT engine is
            loaded once instead of on every call. ``model_path`` is ignored
            when this is given.
    """
    # CPU: decode. The context manager closes the source file deterministically.
    with Image.open(input_path) as src:
        img = np.array(src.convert("RGB"))
    h, w = img.shape[:2]

    # CPU -> GPU
    img_gpu = cp.asarray(img)

    # GPU: patch extraction
    patches, positions = extract_patches_gpu(img_gpu, PATCH_SIZE)

    # Inference (normalize on GPU, infer via ORT, results back on GPU)
    if session is None:
        session = create_session(model_path)
    output = infer_patches(session, patches)

    # GPU: threshold
    binary = postprocess_gpu(output)

    # GPU: reconstruct
    result_gpu = reconstruct_gpu(binary, positions, w, h)

    # GPU -> CPU
    result = result_gpu.get()

    # Drop the GPU arrays before the CPU-side TIFF encode. This only releases
    # references. With CuPy's default memory pool the freed blocks stay cached
    # for reuse; cp.get_default_memory_pool().free_all_blocks() would return
    # them to the driver.
    del img_gpu, patches, output, binary, result_gpu

    # CPU: save
    Image.fromarray(result, "L").convert("1").save(
        output_path, format="TIFF", compression="group4", dpi=(300, 300)
    )
    print(f"Saved {output_path}")


if __name__ == "__main__":
    if len(sys.argv) < 3:
        # sys.exit with a string prints it to stderr and exits with status 1,
        # so the usage message does not pollute stdout.
        sys.exit(f"Usage: {sys.argv[0]} <input.jpg> <output.tif> [model.onnx]")
    # Optional third argument overrides the default model path.
    model = sys.argv[3] if len(sys.argv) > 3 else MODEL
    binarize_image(sys.argv[1], sys.argv[2], model)