from __future__ import annotations import numpy as np from PIL import Image from processing_htr import HTRProcessor def test_preprocess_shape_stride_alignment() -> None: processor = HTRProcessor(characters=["а", "б", "в"], image_height=64, image_max_width=256, width_stride=32) img = Image.fromarray((np.random.rand(53, 117) * 255).astype(np.uint8)) out = processor(images=img, return_tensors="np")["pixel_values"] assert out.shape[0] == 1 assert out.shape[1] == 1 assert out.shape[2] == 64 assert out.shape[3] % 32 == 0 def test_ctc_batch_decode() -> None: processor = HTRProcessor(characters=["а", "б", "в"], image_height=64, image_max_width=256, width_stride=32) logits = np.array( [ [[0.1, 0.9, 0.0, 0.0], [0.1, 0.0, 0.9, 0.0]], [[0.1, 0.9, 0.0, 0.0], [0.1, 0.0, 0.9, 0.0]], [[0.9, 0.0, 0.0, 0.0], [0.9, 0.0, 0.0, 0.0]], ], dtype=np.float32, ) # T,N,C texts = processor.batch_decode(logits, logit_layout="tnc") assert texts == ["а", "б"]