File size: 8,796 Bytes
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91a1214
 
 
 
 
 
 
 
 
 
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08f1adc
3a2e5f0
 
 
08f1adc
 
3a2e5f0
08f1adc
 
 
 
 
 
 
 
 
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
"""``CaptionTokenizer`` — typed wrapper around ``tf.keras.layers.TextVectorization``.

Why a wrapper instead of using the Keras layer directly?

1. **Stable interface for the model.** The model code calls
   ``tokenizer.encode(captions)`` and ``tokenizer.decode_id(idx)``. The fact
   that those happen to delegate to a Keras layer is an implementation
   detail. In Phase 5 we may swap the implementation for HuggingFace
   ``tokenizers`` without rewriting the encoder, decoder, or inference loop.
2. **Persistence.** The notebook saves the *vocabulary list* with pickle, but
   loading requires re-instantiating a layer and calling ``set_vocabulary``.
   That ceremony belongs inside the wrapper, not at every call site.
3. **A JSON sidecar.** Pickle is fast but opaque and risky to load from
   untrusted sources. We additionally write a ``vocab.json`` file (a UTF-8
   JSON array of the tokens, pretty-printed one token per line) so humans and
   other tools can inspect the vocabulary.

The wrapper preserves the notebook's behaviour exactly: ``standardize=None``,
``output_sequence_length`` is set to ``max_length``, and ``encode`` accepts
either a single string or a list of strings (matching the layer's call form
used in cells 7 and 25).
"""

from __future__ import annotations

import json
import pickle
from collections.abc import Iterable
from pathlib import Path

# Filename of the pickled vocabulary list. CaptionTokenizer.save() writes it,
# and CaptionTokenizer.load() reads it first when it exists.
VOCAB_PICKLE_FILENAME = "vocab.pkl"
# Filename of the JSON copy of the same vocabulary list. save() writes it
# alongside the pickle; load() falls back to it when the pickle is absent.
VOCAB_JSON_FILENAME = "vocab.json"


class CaptionTokenizer:
    """Wrapper that owns a fitted ``TextVectorization`` layer + lookup tables.

    Build one with ``CaptionTokenizer(vocab_size, max_length)`` and then call
    ``fit``, or get a ready-to-use instance from ``CaptionTokenizer.load``.
    Until one of those has run, ``encode``, ``decode_id``, ``word_to_id``,
    ``save`` and the ``vocabulary`` / ``vocabulary_size`` / ``layer``
    properties raise ``RuntimeError``.
    """

    def __init__(self, vocab_size: int, max_length: int) -> None:
        """Construct an unfit tokenizer.

        Args:
            vocab_size: Maximum vocabulary size (notebook: ``VOCABULARY_SIZE``).
            max_length: Pad/truncate every caption to this many tokens
                (notebook: ``MAX_LENGTH``).
        """
        self.vocab_size = vocab_size
        self.max_length = max_length
        # All three are populated together by fit()/load(); None means "unfit".
        self._layer = None
        self._idx2word = None  # inverse StringLookup: integer id -> token
        self._word2idx = None  # forward StringLookup: token -> integer id

    # ----------------------------------------------------------------- fit ----

    def fit(self, captions: Iterable[str]) -> None:
        """Adapt the underlying TextVectorization layer to the given captions.

        Calling ``fit`` again discards the previous layer and lookup tables.

        Args:
            captions: An iterable of *already preprocessed* captions
                (i.e. lower-cased, punctuation-stripped, wrapped in
                ``[start] ... [end]``). Mirrors notebook cell 7 which calls
                ``tokenizer.adapt(captions['caption'])`` *after* cell 4 has
                applied ``preprocess`` to every row.
        """
        import tensorflow as tf

        layer = tf.keras.layers.TextVectorization(
            max_tokens=self.vocab_size,
            standardize=None,
            output_sequence_length=self.max_length,
        )
        # Materialize first so one-shot iterators/generators are accepted.
        layer.adapt(list(captions))
        self._layer = layer
        self._build_lookups()

    # ----------------------------------------------------------- properties ---

    @property
    def vocabulary(self) -> list[str]:
        """Return the fitted vocabulary list (same order as TextVectorization)."""
        layer = self._require_fit()
        return list(layer.get_vocabulary())

    @property
    def vocabulary_size(self) -> int:
        """Number of tokens in the fitted vocabulary.

        This is the value reported by ``TextVectorization.vocabulary_size()``.
        """
        return int(self._require_fit().vocabulary_size())

    @property
    def layer(self):
        """Direct access to the inner Keras layer.

        Exposed because the model's ``Embeddings`` layer (notebook cell 19)
        needs ``tokenizer.vocabulary_size()`` at construction time. Phase 1b
        replaces this with a constructor argument and removes the property.
        """
        return self._require_fit()

    # -------------------------------------------------------- encode/decode ---

    def encode(self, text):
        """Encode ``text`` (str or list[str]) to integer-id tensor.

        Mirrors ``tokenizer(text)`` in notebook cells 7 and 25. Single string
        returns a 1-D tensor of shape ``[max_length]``; list returns 2-D.
        """
        return self._require_fit()(text)

    def decode_id(self, idx) -> str:
        """Inverse-lookup a single integer id to its string token.

        Mirrors notebook cell 25's
        ``idx2word(pred_idx).numpy().decode('utf-8')``.
        """
        self._require_fit()
        # By invariant, _idx2word is set together with _layer in fit/load.
        assert self._idx2word is not None
        word = self._idx2word(idx)
        return word.numpy().decode("utf-8")

    def word_to_id(self, word: str) -> int:
        """Look up a single word's integer id, returning 1 (the OOV id) if absent.

        Used by beam search to seed beams with the ``[start]`` token without
        going through ``TextVectorization``'s padded-string path.
        """
        self._require_fit()
        # By invariant, _word2idx is set together with _layer in fit/load.
        assert self._word2idx is not None
        return int(self._word2idx(word).numpy())

    # ---------------------------------------------------------- persistence ---

    def save(self, directory: str | Path) -> None:
        """Save the vocabulary to ``directory/vocab.pkl`` and ``vocab.json``.

        The pickle matches notebook cell 9 exactly so old artefacts remain
        loadable. The JSON sidecar is human-inspectable. ``directory`` (and
        any missing parents) is created if needed.

        Raises:
            RuntimeError: The tokenizer has not been fitted or loaded.
        """
        self._require_fit()
        directory = Path(directory)
        directory.mkdir(parents=True, exist_ok=True)
        vocab = self.vocabulary
        with (directory / VOCAB_PICKLE_FILENAME).open("wb") as f:
            pickle.dump(vocab, f)
        with (directory / VOCAB_JSON_FILENAME).open("w", encoding="utf-8") as f:
            json.dump(vocab, f, ensure_ascii=False, indent=2)

    @classmethod
    def load(
        cls,
        directory: str | Path,
        vocab_size: int,
        max_length: int,
    ) -> CaptionTokenizer:
        """Load a previously saved vocabulary into a new tokenizer.

        ``vocab.pkl`` is read when present; otherwise ``vocab.json`` is used.

        Args:
            directory: Directory containing ``vocab.pkl`` (or ``vocab.json``).
            vocab_size: Maximum vocabulary size. Pass the training-time value.
                It must be at least the number of saved tokens, or
                ``set_vocabulary`` rejects the vocabulary.
            max_length: Pad/truncate length — must match training-time value.

        Returns:
            A fitted ``CaptionTokenizer`` ready to ``encode`` and ``decode_id``.

        Raises:
            FileNotFoundError: Neither vocabulary file exists in ``directory``.
            TypeError: The file did not contain a sequence of tokens (for
                example a single string or a mapping).
        """
        directory = Path(directory)
        pkl = directory / VOCAB_PICKLE_FILENAME
        js = directory / VOCAB_JSON_FILENAME
        if pkl.is_file():
            # SECURITY: unpickling can execute arbitrary code. Only load from
            # directories you trust; vocab.json is the safe interchange format.
            with pkl.open("rb") as f:
                vocab = pickle.load(f)
            source = pkl
        elif js.is_file():
            with js.open(encoding="utf-8") as f:
                vocab = json.load(f)
            source = js
        else:
            raise FileNotFoundError(
                f"No tokenizer vocabulary found in {directory!s}. "
                f"Expected '{VOCAB_PICKLE_FILENAME}' (preferred) or "
                f"'{VOCAB_JSON_FILENAME}'. Train the model with "
                "`python -m scripts.train --config configs/base.yaml` to "
                "produce the artefacts, or point BACKEND_TOKENIZER_DIR at a "
                "directory that contains them."
            )

        # set_vocabulary treats a plain string as a *path* to a vocabulary
        # file, and a mapping has no defined token order. Reject both (and
        # non-iterables) here with a clear message instead of an obscure
        # failure deep inside Keras.
        if isinstance(vocab, (str, bytes, dict)) or not isinstance(vocab, Iterable):
            raise TypeError(
                f"Tokenizer vocabulary in {source!s} must be a list of tokens, "
                f"got {type(vocab).__name__}."
            )

        # Deferred until after the file checks so the errors above do not
        # require TensorFlow to be importable.
        import tensorflow as tf

        tok = cls(vocab_size=vocab_size, max_length=max_length)
        layer = tf.keras.layers.TextVectorization(
            max_tokens=vocab_size,
            standardize=None,
            output_sequence_length=max_length,
        )
        layer.set_vocabulary(vocab)
        tok._layer = layer
        tok._build_lookups()
        return tok

    # -------------------------------------------------------------- internal --

    def _build_lookups(self) -> None:
        """Construct the forward and inverse ``StringLookup`` tables.

        ``_word2idx`` maps token -> id (used by ``word_to_id``).
        ``_idx2word`` maps id -> token (``invert=True``; used by
        ``decode_id``). Both are built from the layer's vocabulary with
        ``mask_token=""``.

        Called only from ``fit()`` and ``load()``, *after* ``self._layer`` has
        been assigned, so the assertion below is a defensive no-op for mypy.
        """
        import tensorflow as tf

        assert self._layer is not None
        vocab = self._layer.get_vocabulary()
        self._word2idx = tf.keras.layers.StringLookup(mask_token="", vocabulary=vocab)
        self._idx2word = tf.keras.layers.StringLookup(mask_token="", vocabulary=vocab, invert=True)

    def _require_fit(self):
        """Validate that the tokenizer has been fitted; return the inner layer.

        Returning the layer (rather than only raising on the unfit state)
        gives callers a non-``None``-typed local for the rest of their body —
        which is what mypy needs to prove ``layer.get_vocabulary()`` etc.
        are valid calls. Costs one attribute lookup at runtime.

        Raises:
            RuntimeError: Neither ``fit`` nor ``load`` has populated the layer.
        """
        if self._layer is None:
            raise RuntimeError(
                "CaptionTokenizer not fitted. Call `.fit(captions)` or "
                "`.load(directory, ...)` first."
            )
        return self._layer