image-captioning-api / tests /unit /test_tokenizer.py
apoorvrajdev's picture
feat(evaluation): add beam search, metrics pipeline, and stabilized training workflow
91a1214
"""Tests for ``captioning.preprocessing.tokenizer.CaptionTokenizer``.
These are TF-dependent and slow to import; pytest auto-skips if TF is missing.
"""
from __future__ import annotations
from pathlib import Path
import pytest
tf = pytest.importorskip("tensorflow")
from captioning.preprocessing.tokenizer import ( # noqa: E402
VOCAB_JSON_FILENAME,
VOCAB_PICKLE_FILENAME,
CaptionTokenizer,
)
def test_fit_then_encode_decode_roundtrip(tiny_caption_corpus: list[str]) -> None:
    """Fit on the corpus, encode one caption, and decode its leading id."""
    tokenizer = CaptionTokenizer(vocab_size=200, max_length=20)
    tokenizer.fit(tiny_caption_corpus)

    sample_caption = tiny_caption_corpus[0]
    encoded = tokenizer.encode([sample_caption])
    # One caption in -> one row out, with width equal to the configured max_length.
    assert encoded.shape == (1, 20)

    # The id at position (0, 0) must decode to a string token.
    leading_id = int(encoded[0, 0].numpy())
    decoded = tokenizer.decode_id(leading_id)
    assert isinstance(decoded, str)
def test_save_load_round_trip_matches_original(
    tiny_caption_corpus: list[str], tmp_artifacts_dir: Path
) -> None:
    """A tokenizer saved to disk and loaded back matches the fitted original."""
    original = CaptionTokenizer(vocab_size=200, max_length=20)
    original.fit(tiny_caption_corpus)
    original.save(tmp_artifacts_dir)

    # Saving must leave both vocabulary artifacts in the target directory.
    for artifact_name in (VOCAB_PICKLE_FILENAME, VOCAB_JSON_FILENAME):
        assert (tmp_artifacts_dir / artifact_name).is_file()

    restored = CaptionTokenizer.load(tmp_artifacts_dir, vocab_size=200, max_length=20)
    assert restored.vocabulary == original.vocabulary

    # The same caption must encode to exactly the same ids on both instances.
    sample_caption = tiny_caption_corpus[0]
    ids_from_original = original.encode([sample_caption]).numpy().tolist()
    ids_from_restored = restored.encode([sample_caption]).numpy().tolist()
    assert ids_from_original == ids_from_restored
def test_unfitted_tokenizer_raises(tmp_artifacts_dir: Path) -> None:
    """Reading the vocabulary, encoding, and saving all fail before ``fit``."""
    unfitted = CaptionTokenizer(vocab_size=200, max_length=20)

    def expect_not_fitted():
        # A fresh context manager per check so each operation is verified on its own.
        return pytest.raises(RuntimeError, match="not fitted")

    with expect_not_fitted():
        _ = unfitted.vocabulary

    with expect_not_fitted():
        unfitted.encode(["hello"])

    with expect_not_fitted():
        unfitted.save(tmp_artifacts_dir)
def test_max_length_is_respected(tiny_caption_corpus: list[str]) -> None:
    """A caption with more tokens than ``max_length`` encodes to ``max_length`` ids."""
    max_length = 10
    tokenizer = CaptionTokenizer(vocab_size=200, max_length=max_length)
    tokenizer.fit(tiny_caption_corpus)

    # 32 whitespace-separated tokens: well beyond the 10-id limit.
    tokens = ["[start]", *(["word"] * 30), "[end]"]
    oversized_caption = " ".join(tokens)

    encoded = tokenizer.encode([oversized_caption])
    assert encoded.shape == (1, max_length)
def test_word_to_id_round_trips_with_decode(tiny_caption_corpus: list[str]) -> None:
    """``decode_id`` undoes ``word_to_id`` for the ``[start]`` and ``[end]`` markers."""
    tokenizer = CaptionTokenizer(vocab_size=200, max_length=20)
    tokenizer.fit(tiny_caption_corpus)

    start_token, end_token = "[start]", "[end]"
    start_id = tokenizer.word_to_id(start_token)
    end_id = tokenizer.word_to_id(end_token)

    assert isinstance(start_id, int)
    # The two markers must map to distinct ids.
    assert start_id != end_id

    # Mapping each id back must recover the token it came from.
    assert tokenizer.decode_id(start_id) == start_token
    assert tokenizer.decode_id(end_id) == end_token