File size: 1,065 Bytes
b5b608e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 | from __future__ import annotations
import numpy as np
from PIL import Image
from processing_htr import HTRProcessor
def test_preprocess_shape_stride_alignment() -> None:
    """Check that preprocessing a non-aligned image yields a stride-aligned batch.

    A 53x117 single-channel uint8 image (neither side a multiple of 32) is fed
    through ``HTRProcessor`` configured with ``image_height=64`` and
    ``width_stride=32``. The resulting ``pixel_values`` must be a rank-4 array
    of shape ``(1, 1, 64, W)`` where ``W`` is a positive multiple of 32.
    """
    processor = HTRProcessor(
        characters=["а", "б", "в"],
        image_height=64,
        image_max_width=256,
        width_stride=32,
    )
    # Seeded generator so any failure is reproducible; the assertions below
    # only inspect the output shape, not pixel values.
    rng = np.random.default_rng(0)
    img = Image.fromarray((rng.random((53, 117)) * 255).astype(np.uint8))
    out = processor(images=img, return_tensors="np")["pixel_values"]
    # Check rank first so a wrong layout fails with a readable message
    # instead of an IndexError on shape[3].
    assert out.ndim == 4, out.shape
    # NOTE(review): layout presumably (batch, channel, height, width) — confirm
    # against HTRProcessor.
    assert out.shape[0] == 1
    assert out.shape[1] == 1
    assert out.shape[2] == 64
    # 0 % 32 == 0, so without this check an empty width would pass the
    # stride-alignment assertion.
    assert out.shape[3] > 0
    assert out.shape[3] % 32 == 0
def test_ctc_batch_decode() -> None:
    """Greedy CTC decoding of a time-major (T, N, C) batch of two sequences.

    There are 4 classes for 3 characters. The expected output implies class 0
    is the blank and classes 1..3 map to "а", "б", "в" (hedged: inferred from
    the assertion, not from HTRProcessor itself). Sequence 0 peaks on class 1
    twice and then blank, so it collapses to "а". Sequence 1 peaks on class 2
    twice and then blank, so it collapses to "б".
    """
    processor = HTRProcessor(
        characters=["а", "б", "в"],
        image_height=64,
        image_max_width=256,
        width_stride=32,
    )

    # Per-sequence frames, each row is one time step over the 4 classes.
    first_sequence = np.asarray(
        [
            [0.1, 0.9, 0.0, 0.0],
            [0.1, 0.9, 0.0, 0.0],
            [0.9, 0.0, 0.0, 0.0],
        ],
        dtype=np.float32,
    )
    second_sequence = np.asarray(
        [
            [0.1, 0.0, 0.9, 0.0],
            [0.1, 0.0, 0.9, 0.0],
            [0.9, 0.0, 0.0, 0.0],
        ],
        dtype=np.float32,
    )
    # Stacking on axis 1 puts the batch dimension second: shape (T=3, N=2, C=4).
    time_major_logits = np.stack([first_sequence, second_sequence], axis=1)

    decoded = processor.batch_decode(time_major_logits, logit_layout="tnc")
    assert decoded == ["а", "б"]
|