File size: 2,746 Bytes
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91a1214
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Tests for ``captioning.preprocessing.tokenizer.CaptionTokenizer``.

These are TF-dependent and slow to import; pytest auto-skips if TF is missing.
"""

from __future__ import annotations

from pathlib import Path

import pytest

tf = pytest.importorskip("tensorflow")

from captioning.preprocessing.tokenizer import (  # noqa: E402
    VOCAB_JSON_FILENAME,
    VOCAB_PICKLE_FILENAME,
    CaptionTokenizer,
)


def test_fit_then_encode_decode_roundtrip(tiny_caption_corpus: list[str]) -> None:
    tok = CaptionTokenizer(vocab_size=200, max_length=20)
    tok.fit(tiny_caption_corpus)

    ids = tok.encode([tiny_caption_corpus[0]])
    assert ids.shape == (1, 20)

    # Decoding the first non-padding id should produce a known token.
    first_id = int(ids[0, 0].numpy())
    word = tok.decode_id(first_id)
    assert isinstance(word, str)


def test_save_load_round_trip_matches_original(
    tiny_caption_corpus: list[str], tmp_artifacts_dir: Path
) -> None:
    tok = CaptionTokenizer(vocab_size=200, max_length=20)
    tok.fit(tiny_caption_corpus)
    tok.save(tmp_artifacts_dir)

    assert (tmp_artifacts_dir / VOCAB_PICKLE_FILENAME).is_file()
    assert (tmp_artifacts_dir / VOCAB_JSON_FILENAME).is_file()

    loaded = CaptionTokenizer.load(tmp_artifacts_dir, vocab_size=200, max_length=20)
    assert loaded.vocabulary == tok.vocabulary
    # Encoding should match exactly
    ids_a = tok.encode([tiny_caption_corpus[0]]).numpy().tolist()
    ids_b = loaded.encode([tiny_caption_corpus[0]]).numpy().tolist()
    assert ids_a == ids_b


def test_unfitted_tokenizer_raises(tmp_artifacts_dir: Path) -> None:
    tok = CaptionTokenizer(vocab_size=200, max_length=20)
    with pytest.raises(RuntimeError, match="not fitted"):
        _ = tok.vocabulary
    with pytest.raises(RuntimeError, match="not fitted"):
        tok.encode(["hello"])
    with pytest.raises(RuntimeError, match="not fitted"):
        tok.save(tmp_artifacts_dir)


def test_max_length_is_respected(tiny_caption_corpus: list[str]) -> None:
    tok = CaptionTokenizer(vocab_size=200, max_length=10)
    tok.fit(tiny_caption_corpus)
    long_caption = " ".join(["[start]"] + ["word"] * 30 + ["[end]"])
    ids = tok.encode([long_caption])
    assert ids.shape == (1, 10)


def test_word_to_id_round_trips_with_decode(tiny_caption_corpus: list[str]) -> None:
    """``word_to_id`` is the inverse of ``decode_id`` for in-vocabulary tokens."""
    tok = CaptionTokenizer(vocab_size=200, max_length=20)
    tok.fit(tiny_caption_corpus)
    start_id = tok.word_to_id("[start]")
    end_id = tok.word_to_id("[end]")
    assert isinstance(start_id, int)
    assert start_id != end_id
    assert tok.decode_id(start_id) == "[start]"
    assert tok.decode_id(end_id) == "[end]"