Spaces:
Configuration error
Configuration error
File size: 2,746 Bytes
"""Tests for ``captioning.preprocessing.tokenizer.CaptionTokenizer``.

These tests depend on TensorFlow, which is slow to import; the whole module is
skipped automatically (via ``pytest.importorskip``) when TensorFlow is missing.
"""
from __future__ import annotations
from pathlib import Path
import pytest
tf = pytest.importorskip("tensorflow")
from captioning.preprocessing.tokenizer import ( # noqa: E402
VOCAB_JSON_FILENAME,
VOCAB_PICKLE_FILENAME,
CaptionTokenizer,
)
def test_fit_then_encode_decode_roundtrip(tiny_caption_corpus: list[str]) -> None:
    """Encoded ids decode to words that map back to the same ids.

    The id at the first position of an encoded caption is decoded to a word,
    and ``word_to_id`` must map that word back to the same id. This checks a
    genuine id -> word -> id round trip rather than only the return type.
    """
    tok = CaptionTokenizer(vocab_size=200, max_length=20)
    tok.fit(tiny_caption_corpus)
    ids = tok.encode([tiny_caption_corpus[0]])
    assert ids.shape == (1, 20)
    # Take the id at position 0 of the encoded sequence. This is the caption's
    # first token if padding is appended after the tokens (assumption - TODO
    # confirm against CaptionTokenizer.encode).
    first_id = int(ids[0, 0].numpy())
    word = tok.decode_id(first_id)
    assert isinstance(word, str)
    # ``word_to_id`` is the inverse of ``decode_id`` for in-vocabulary tokens.
    # A word obtained by decoding an id should presumably be in-vocabulary,
    # so mapping it back must recover the original id.
    assert tok.word_to_id(word) == first_id
def test_save_load_round_trip_matches_original(
    tiny_caption_corpus: list[str], tmp_artifacts_dir: Path
) -> None:
    """A tokenizer reloaded from disk reproduces the original vocabulary and encodings."""
    original = CaptionTokenizer(vocab_size=200, max_length=20)
    original.fit(tiny_caption_corpus)
    original.save(tmp_artifacts_dir)

    # ``save`` is expected to write both the pickle and the JSON vocabulary files.
    for artifact_name in (VOCAB_PICKLE_FILENAME, VOCAB_JSON_FILENAME):
        assert (tmp_artifacts_dir / artifact_name).is_file()

    restored = CaptionTokenizer.load(tmp_artifacts_dir, vocab_size=200, max_length=20)
    assert restored.vocabulary == original.vocabulary

    # The same caption must encode to identical id sequences.
    sample = tiny_caption_corpus[0]
    original_ids = original.encode([sample]).numpy().tolist()
    restored_ids = restored.encode([sample]).numpy().tolist()
    assert original_ids == restored_ids
def test_unfitted_tokenizer_raises(tmp_artifacts_dir: Path) -> None:
    """Vocabulary access, encoding and saving all raise before ``fit`` is called."""
    unfitted = CaptionTokenizer(vocab_size=200, max_length=20)
    # Each callable touches functionality that requires a fitted vocabulary;
    # they are checked in order: property access, encode, then save.
    needs_vocabulary = (
        lambda: unfitted.vocabulary,
        lambda: unfitted.encode(["hello"]),
        lambda: unfitted.save(tmp_artifacts_dir),
    )
    for action in needs_vocabulary:
        with pytest.raises(RuntimeError, match="not fitted"):
            action()
def test_max_length_is_respected(tiny_caption_corpus: list[str]) -> None:
    """A caption longer than ``max_length`` still encodes to exactly ``max_length`` ids."""
    limit = 10
    tokenizer = CaptionTokenizer(vocab_size=200, max_length=limit)
    tokenizer.fit(tiny_caption_corpus)
    # 32 whitespace-separated tokens: far more than the configured limit.
    oversized = " ".join(["[start]", *(["word"] * 30), "[end]"])
    encoded = tokenizer.encode([oversized])
    assert encoded.shape == (1, limit)
def test_word_to_id_round_trips_with_decode(tiny_caption_corpus: list[str]) -> None:
    """``word_to_id`` is the inverse of ``decode_id`` for in-vocabulary tokens."""
    tokenizer = CaptionTokenizer(vocab_size=200, max_length=20)
    tokenizer.fit(tiny_caption_corpus)

    # Look up the sequence markers in insertion order: "[start]" first, then "[end]".
    marker_ids = {
        marker: tokenizer.word_to_id(marker) for marker in ("[start]", "[end]")
    }

    assert isinstance(marker_ids["[start]"], int)
    assert marker_ids["[start]"] != marker_ids["[end]"]
    for marker, marker_id in marker_ids.items():
        assert tokenizer.decode_id(marker_id) == marker
|