"""Tests for ``captioning.preprocessing.tokenizer.CaptionTokenizer``. These are TF-dependent and slow to import; pytest auto-skips if TF is missing. """ from __future__ import annotations from pathlib import Path import pytest tf = pytest.importorskip("tensorflow") from captioning.preprocessing.tokenizer import ( # noqa: E402 VOCAB_JSON_FILENAME, VOCAB_PICKLE_FILENAME, CaptionTokenizer, ) def test_fit_then_encode_decode_roundtrip(tiny_caption_corpus: list[str]) -> None: tok = CaptionTokenizer(vocab_size=200, max_length=20) tok.fit(tiny_caption_corpus) ids = tok.encode([tiny_caption_corpus[0]]) assert ids.shape == (1, 20) # Decoding the first non-padding id should produce a known token. first_id = int(ids[0, 0].numpy()) word = tok.decode_id(first_id) assert isinstance(word, str) def test_save_load_round_trip_matches_original( tiny_caption_corpus: list[str], tmp_artifacts_dir: Path ) -> None: tok = CaptionTokenizer(vocab_size=200, max_length=20) tok.fit(tiny_caption_corpus) tok.save(tmp_artifacts_dir) assert (tmp_artifacts_dir / VOCAB_PICKLE_FILENAME).is_file() assert (tmp_artifacts_dir / VOCAB_JSON_FILENAME).is_file() loaded = CaptionTokenizer.load(tmp_artifacts_dir, vocab_size=200, max_length=20) assert loaded.vocabulary == tok.vocabulary # Encoding should match exactly ids_a = tok.encode([tiny_caption_corpus[0]]).numpy().tolist() ids_b = loaded.encode([tiny_caption_corpus[0]]).numpy().tolist() assert ids_a == ids_b def test_unfitted_tokenizer_raises(tmp_artifacts_dir: Path) -> None: tok = CaptionTokenizer(vocab_size=200, max_length=20) with pytest.raises(RuntimeError, match="not fitted"): _ = tok.vocabulary with pytest.raises(RuntimeError, match="not fitted"): tok.encode(["hello"]) with pytest.raises(RuntimeError, match="not fitted"): tok.save(tmp_artifacts_dir) def test_max_length_is_respected(tiny_caption_corpus: list[str]) -> None: tok = CaptionTokenizer(vocab_size=200, max_length=10) tok.fit(tiny_caption_corpus) long_caption = " ".join(["[start]"] + ["word"] * 30 + ["[end]"]) ids = tok.encode([long_caption]) assert ids.shape == (1, 10) def test_word_to_id_round_trips_with_decode(tiny_caption_corpus: list[str]) -> None: """``word_to_id`` is the inverse of ``decode_id`` for in-vocabulary tokens.""" tok = CaptionTokenizer(vocab_size=200, max_length=20) tok.fit(tiny_caption_corpus) start_id = tok.word_to_id("[start]") end_id = tok.word_to_id("[end]") assert isinstance(start_id, int) assert start_id != end_id assert tok.decode_id(start_id) == "[start]" assert tok.decode_id(end_id) == "[end]"