File size: 3,150 Bytes
1d1e600
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""``tf.data`` pipeline construction.

Mirrors notebook cell 13 (the ``load_data`` map function) and cell 14 (the
two pipeline definitions). The notebook closes over a global ``tokenizer``;
we pass the tokenizer in explicitly so the same code works in tests, scripts,
and the parity audit.

Note on ``shuffle`` in the val pipeline: the notebook shuffles both train
and val. Shuffling is technically unnecessary for validation but harmless, so
we preserve it here for parity; Phase 1b will remove it from val.
"""

from __future__ import annotations

from collections.abc import Sequence

from captioning.preprocessing.image import preprocess_image_tensor
from captioning.preprocessing.tokenizer import CaptionTokenizer


def _make_load_data_fn(tokenizer: CaptionTokenizer):
    """Create the per-example ``tf.data`` map function bound to ``tokenizer``.

    The returned callable maps ``(image_path, caption)`` to ``(image, ids)``:
    it reads the file at ``image_path``, decodes it as a 3-channel JPEG, runs
    ``preprocess_image_tensor`` on the result, and encodes ``caption`` with the
    captured tokenizer. ``tf.data`` invokes it once per example, passing both
    arguments as ``tf.string`` tensors.

    A closure is used (instead of a module-level function reading a global
    tokenizer, as the notebook does) so the tokenizer is supplied explicitly.
    """
    import tensorflow as tf

    def load_data(image_path, caption):
        # channels=3 forces RGB output regardless of the source JPEG's channels.
        decoded = tf.io.decode_jpeg(tf.io.read_file(image_path), channels=3)
        # Tuple elements evaluate left to right: image first, then token ids.
        return preprocess_image_tensor(decoded), tokenizer.encode(caption)

    return load_data


def build_train_pipeline(
    image_paths: Sequence[str],
    captions: Sequence[str],
    tokenizer: CaptionTokenizer,
    batch_size: int,
    buffer_size: int,
):
    """Build the training dataset, matching notebook cell 14.

    The op sequence is map -> shuffle -> batch, exactly as in the notebook.

    Args:
        image_paths: One path per (image, caption) pair (image-level split
            already applied — see ``data.splits``).
        captions: Preprocessed captions, one per ``image_paths`` entry.
        tokenizer: Fitted ``CaptionTokenizer``.
        batch_size: Mini-batch size; matches ``BATCH_SIZE`` in the notebook.
        buffer_size: Shuffle buffer size; matches ``BUFFER_SIZE``.

    Returns:
        A ``tf.data.Dataset`` yielding ``(image, token_ids)`` batches. Empty
        inputs produce an empty dataset.

    Raises:
        ValueError: If ``image_paths`` and ``captions`` differ in length.
    """
    import tensorflow as tf

    paths = list(image_paths)
    texts = list(captions)
    if len(paths) != len(texts):
        raise ValueError(
            f"image_paths and captions must have the same length, "
            f"got {len(paths)} and {len(texts)}"
        )

    # Pin the dtype: from_tensor_slices on an empty Python list infers float32,
    # which would make tf.io.read_file inside the map fail with a dtype error.
    slices = (
        tf.constant(paths, dtype=tf.string),
        tf.constant(texts, dtype=tf.string),
    )

    load_data = _make_load_data_fn(tokenizer)
    return (
        tf.data.Dataset.from_tensor_slices(slices)
        .map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
        .shuffle(buffer_size)
        .batch(batch_size)
    )


def build_val_pipeline(
    image_paths: Sequence[str],
    captions: Sequence[str],
    tokenizer: CaptionTokenizer,
    batch_size: int,
    buffer_size: int,
):
    """Build the validation dataset.

    Identical structure to ``build_train_pipeline`` (map -> shuffle -> batch).
    It is a separate function so that Phase 1b can drop the (unnecessary)
    shuffle from val without coupling the change to train.

    Args:
        image_paths: One path per (image, caption) pair.
        captions: Preprocessed captions, one per ``image_paths`` entry.
        tokenizer: Fitted ``CaptionTokenizer``.
        batch_size: Mini-batch size.
        buffer_size: Shuffle buffer size (kept for notebook parity).

    Returns:
        A ``tf.data.Dataset`` yielding ``(image, token_ids)`` batches. Empty
        inputs produce an empty dataset.

    Raises:
        ValueError: If ``image_paths`` and ``captions`` differ in length.
    """
    import tensorflow as tf

    paths = list(image_paths)
    texts = list(captions)
    if len(paths) != len(texts):
        raise ValueError(
            f"image_paths and captions must have the same length, "
            f"got {len(paths)} and {len(texts)}"
        )

    # Pin the dtype: from_tensor_slices on an empty Python list infers float32,
    # which would make tf.io.read_file inside the map fail with a dtype error.
    slices = (
        tf.constant(paths, dtype=tf.string),
        tf.constant(texts, dtype=tf.string),
    )

    load_data = _make_load_data_fn(tokenizer)
    return (
        tf.data.Dataset.from_tensor_slices(slices)
        .map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
        .shuffle(buffer_size)  # Notebook cell 14 shuffles val too — preserved.
        .batch(batch_size)
    )