Spaces:
Configuration error
Configuration error
File size: 1,993 Bytes
"""Token + positional embedding layer.

Mirrors notebook cell 18 verbatim. The decoder learns its own positional
encoding (rather than using a fixed sinusoidal one) — that's the published
architecture, preserved here.
"""
from __future__ import annotations
def _import_tf():
    """Import TensorFlow on demand and return the module object.

    The import statement lives inside the function body, so it only runs when
    this function is called rather than when the enclosing package is merely
    imported. The motivation is to keep
    ``from captioning.models import Embeddings`` from paying for a
    multi-second TensorFlow import on behalf of callers that never use it.

    Returns:
        The imported ``tensorflow`` module.
    """
    import tensorflow

    return tensorflow
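# Defining the class inside a factory keeps the TensorFlow import inside a
# function call rather than a top-level ``import`` statement. This module
# itself calls the factory once at import time
# (``Embeddings = _build_embeddings_class()`` at the bottom of the file), so
# importing this module does import TensorFlow.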
def _build_embeddings_class():
    """Create and return the ``Embeddings`` Keras layer class.

    TensorFlow is obtained through ``_import_tf()`` inside this function, so
    the class (which subclasses ``tf.keras.layers.Layer``) is only defined
    once the factory has been called.

    Returns:
        The ``Embeddings`` layer class itself (not an instance).
    """
    tf = _import_tf()

    class Embeddings(tf.keras.layers.Layer):
        """Sum of token and learned positional embeddings.

        Args:
            vocab_size: Size of the token vocabulary
                (``CaptionTokenizer.vocabulary_size``).
            embed_dim: Dimensionality of each embedding vector
                (``model.embedding_dim``, default 512).
            max_len: Maximum sequence length (``model.max_length``,
                default 40). This is the number of rows in the position
                table, so inputs must not be longer than ``max_len``.
            **kwargs: Standard Keras layer options (for example ``name``,
                ``dtype``, ``trainable``) forwarded to the base ``Layer``.
        """

        def __init__(
            self, vocab_size: int, embed_dim: int, max_len: int, **kwargs
        ) -> None:
            # Forwarding **kwargs lets Keras pass name/dtype/trainable back in
            # when the layer is rebuilt from its config.
            super().__init__(**kwargs)

            # Kept on the instance so get_config() can report them.
            self.vocab_size = vocab_size
            self.embed_dim = embed_dim
            self.max_len = max_len

            self.token_embeddings = tf.keras.layers.Embedding(
                vocab_size, embed_dim
            )
            # One learned vector per position index 0 .. max_len - 1. No
            # ``input_shape`` hint is passed: this sub-layer is only ever
            # called directly on the position ids built in call().
            self.position_embeddings = tf.keras.layers.Embedding(
                max_len, embed_dim
            )

        def call(self, input_ids):
            """Embed token ids and add the matching position embeddings.

            Args:
                input_ids: Integer token ids whose last axis is the sequence
                    axis (presumably ``(batch, seq_len)`` — confirm against
                    the caller).

            Returns:
                Token embeddings plus position embeddings; the position term
                has a leading axis of size 1 and broadcasts over the batch.
            """
            # Use the dynamic shape so variable-length batches work.
            length = tf.shape(input_ids)[-1]
            position_ids = tf.range(start=0, limit=length, delta=1)
            # (length,) -> (1, length) so the lookup result broadcasts.
            position_ids = tf.expand_dims(position_ids, axis=0)

            token_embeddings = self.token_embeddings(input_ids)
            position_embeddings = self.position_embeddings(position_ids)
            return token_embeddings + position_embeddings

        def get_config(self):
            """Return the constructor arguments for Keras serialization.

            Returns:
                The base layer config extended with ``vocab_size``,
                ``embed_dim`` and ``max_len``, so ``from_config`` can
                re-create an equivalent layer.
            """
            config = super().get_config()
            config.update(
                {
                    "vocab_size": self.vocab_size,
                    "embed_dim": self.embed_dim,
                    "max_len": self.max_len,
                }
            )
            return config

    return Embeddings
# Build the layer class once, at import time, and expose it as the
# module-level ``Embeddings`` name. The factory calls ``_import_tf()``, so
# TensorFlow is imported when this module is imported.
Embeddings = _build_embeddings_class()
|