File size: 1,993 Bytes
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""Token + positional embedding layer.

Mirrors notebook cell 18 verbatim. The decoder learns its own positional
encoding (rather than using sinusoidal) — that's the published architecture,
preserved here.
"""

from __future__ import annotations

import functools


def _import_tf():
    """Local import keeps top-level package import lightweight.

    Without this, ``from captioning.models import Embeddings`` would trigger
    a multi-second TF import even for callers that don't use it.
    """
    import tensorflow as tf

    return tf


# The layer class is defined inside a factory so TensorFlow is only imported
# when the factory is first called. The factory is memoized: every call returns
# the same class object, so ``isinstance`` checks and Keras custom-object
# lookups stay consistent no matter how many callers build it.
@functools.lru_cache(maxsize=None)
def _build_embeddings_class():
    """Import TensorFlow and return the ``Embeddings`` Keras layer class."""
    tf = _import_tf()

    class Embeddings(tf.keras.layers.Layer):
        """Sum of token and learned positional embeddings.

        Args:
            vocab_size: Size of the token vocabulary
                (``CaptionTokenizer.vocabulary_size``).
            embed_dim: Dimensionality of each embedding vector
                (``model.embedding_dim``, default 512).
            max_len: Maximum sequence length (``model.max_length``, default 40).
            **kwargs: Standard Keras layer arguments (e.g. ``name``, ``dtype``,
                ``trainable``), forwarded to ``tf.keras.layers.Layer``.
        """

        def __init__(
            self, vocab_size: int, embed_dim: int, max_len: int, **kwargs
        ) -> None:
            # Forward standard layer kwargs so ``from_config`` round-trips work.
            super().__init__(**kwargs)
            # Stored so get_config() can reconstruct the layer.
            self.vocab_size = vocab_size
            self.embed_dim = embed_dim
            self.max_len = max_len
            self.token_embeddings = tf.keras.layers.Embedding(vocab_size, embed_dim)
            # NOTE: ``input_shape`` only matters for the first layer of a model
            # (newer Keras versions warn about it on other layers); it is kept
            # here to mirror the original notebook cell.
            self.position_embeddings = tf.keras.layers.Embedding(
                max_len, embed_dim, input_shape=(None, max_len)
            )

        def call(self, input_ids):
            """Embed ``input_ids`` and add a learned embedding per position.

            The last axis of ``input_ids`` is treated as the sequence axis.
            Position ids run ``0 .. length - 1``, so ``length`` must not
            exceed ``max_len`` or the position lookup goes past its table.
            """
            length = tf.shape(input_ids)[-1]
            position_ids = tf.range(start=0, limit=length, delta=1)
            # Shape (1, length); presumably broadcast against a (batch, length)
            # input when the two embeddings are summed — confirm at call sites.
            position_ids = tf.expand_dims(position_ids, axis=0)
            token_embeddings = self.token_embeddings(input_ids)
            position_embeddings = self.position_embeddings(position_ids)
            return token_embeddings + position_embeddings

        def get_config(self):
            """Return the constructor arguments so Keras can re-create the layer."""
            config = super().get_config()
            config.update(
                {
                    "vocab_size": self.vocab_size,
                    "embed_dim": self.embed_dim,
                    "max_len": self.max_len,
                }
            )
            return config

    return Embeddings


# ``Embeddings`` is resolved lazily; listing it in ``__all__`` keeps
# ``from <module> import *`` exporting it (star-import consults __getattr__).
__all__ = ["Embeddings"]


def __getattr__(name: str):
    """Build the ``Embeddings`` class on first access (PEP 562).

    Resolving the class lazily means importing this module does not import
    TensorFlow; TF is imported the first time ``Embeddings`` is looked up
    (attribute access, ``from ... import Embeddings``, or ``import *``).

    Raises:
        AttributeError: For any name other than ``Embeddings``.
    """
    if name == "Embeddings":
        cls = _build_embeddings_class()
        # Cache in the module namespace so later lookups bypass __getattr__.
        globals()[name] = cls
        return cls
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")