File size: 1,930 Bytes
1b5db69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
Minimal example: binarize a document image using the SBB ONNX model.

    pip install onnxruntime-gpu numpy Pillow
    python3 sample_workflow.py input.jpg output.tif
"""

import sys
import numpy as np
from PIL import Image
import onnxruntime as ort

MODEL = "model_convtranspose.onnx"
PATCH = 448  # the model only accepts fixed-size PATCH x PATCH inputs
BATCH = 64  # maximum tiles per sess.run call

# uint8 -> float32 lookup table for x / 255. The division is done in float64 and
# then rounded to float32, matching the original TF model's float64->float32
# rounding bit-for-bit.
_LUT = (np.arange(256, dtype=np.float64) / 255.0).astype(np.float32)


def extract_patches(img, patch=PATCH):
    """Split an image into zero-padded ``patch`` x ``patch`` RGB tiles.

    Args:
        img: Array of shape (H, W, 3); uint8 RGB in this script.
        patch: Tile edge length in pixels.

    Returns:
        ``(tiles, positions)``: ``tiles`` is a uint8 array of shape
        (N, patch, patch, 3) in row-major tile order, and ``positions`` lists
        the (x, y) top-left corner of each tile within ``img``. Tiles on the
        right and bottom edges are zero-padded.
    """
    h, w = img.shape[:2]
    positions = [(x, y) for y in range(0, h, patch) for x in range(0, w, patch)]
    tiles = np.zeros((len(positions), patch, patch, 3), dtype=np.uint8)
    # Iterating a NumPy array yields views, so writing into `tile` fills `tiles`.
    for tile, (x, y) in zip(tiles, positions):
        ph, pw = min(patch, h - y), min(patch, w - x)
        tile[:ph, :pw] = img[y:y + ph, x:x + pw]
    return tiles, positions


def normalize(tiles):
    """Map uint8 pixel values to float32 in [0, 1] via the rounding-exact LUT."""
    return _LUT[tiles]


def run_model(sess, tiles, batch=BATCH):
    """Run ``sess`` over ``tiles`` in chunks of ``batch`` and stack the outputs.

    Each chunk is normalized just before inference. Only one float32 input
    batch exists at a time, rather than a float copy of every tile.

    Args:
        sess: An ``onnxruntime.InferenceSession`` (or any object with a
            compatible ``run`` method) taking input ``"input_1"`` and
            producing output ``"activation_55"``.
        tiles: uint8 array of shape (N, PATCH, PATCH, 3) with N >= 1.
        batch: Maximum number of tiles per ``sess.run`` call.

    Returns:
        The per-chunk outputs concatenated along axis 0.
    """
    outputs = [
        sess.run(["activation_55"], {"input_1": normalize(tiles[i:i + batch])})[0]
        for i in range(0, len(tiles), batch)
    ]
    return np.concatenate(outputs)


def stitch(probs, positions, h, w, patch=PATCH):
    """Threshold per-tile model output and reassemble an (h, w) uint8 image.

    Channel 1 of each tile is scaled by 255 and truncated to uint8. Values
    <= 128 become white (255) and the rest become black (0). Padding that
    falls outside the original image is discarded. The tiles never overlap,
    so every pixel is written exactly once.
    """
    result = np.zeros((h, w), dtype=np.uint8)
    for tile_probs, (x, y) in zip(probs, positions):
        ah, aw = min(patch, h - y), min(patch, w - x)
        # NOTE(review): channel 1 is presumably the foreground/text class — confirm.
        prob = tile_probs[:ah, :aw, 1]
        result[y:y + ah, x:x + aw] = np.where(
            (prob * 255).astype(np.uint8) <= 128, 255, 0
        )
    return result


def main(argv=None):
    """Binarize image ``argv[1]`` and save it as a Group 4 TIFF at ``argv[2]``.

    Args:
        argv: Command-line vector; defaults to ``sys.argv``. Arguments after
            the two paths are ignored.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 3:
        sys.exit("usage: python3 sample_workflow.py input.jpg output.tif")
    src, dst = argv[1], argv[2]

    # Execution providers are listed in order of preference.
    sess = ort.InferenceSession(
        MODEL, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
    )

    with Image.open(src) as im:
        img = np.array(im.convert("RGB"))
    h, w = img.shape[:2]
    if h == 0 or w == 0:
        sys.exit(f"{src}: image is empty")

    tiles, positions = extract_patches(img)
    probs = run_model(sess, tiles)
    result = stitch(probs, positions, h, w)

    # A 2-D uint8 array maps to mode "L". Converting to "1" is lossless here
    # because every pixel is already exactly 0 or 255.
    Image.fromarray(result).convert("1").save(
        dst, format="TIFF", compression="group4", dpi=(300, 300)
    )
    print(f"Saved {dst}")


if __name__ == "__main__":
    main()