apoorvrajdev's picture
feat: finalize Phase 1 modular ML architecture
3a2e5f0
"""Token + positional embedding layer.
Mirrors notebook cell 18 verbatim. The decoder learns its own positional
encoding (rather than using sinusoidal) — that's the published architecture,
preserved here.
"""
from __future__ import annotations
def _import_tf():
    """Import TensorFlow on demand and hand back the module object.

    The ``import`` statement sits inside the function body, so it executes
    when this function is called rather than as a top-level statement of
    this file.

    Returns:
        The imported ``tensorflow`` module.
    """
    import tensorflow

    return tensorflow
# The class is defined inside a factory so that TensorFlow is obtained through
# a function call (``_import_tf()``) instead of a top-level ``import`` statement.
# This module itself runs ``Embeddings = _build_embeddings_class()`` once, at
# module init.
def _build_embeddings_class():
    """Create and return the ``Embeddings`` Keras layer class.

    TensorFlow is obtained through ``_import_tf()`` inside this factory, and
    the class, which subclasses ``tf.keras.layers.Layer``, is defined in the
    factory body.

    Returns:
        The ``Embeddings`` layer class (the class itself, not an instance).
    """
    tf = _import_tf()

    class Embeddings(tf.keras.layers.Layer):
        """Sum of token and learned positional embeddings.

        Args:
            vocab_size: Size of the token vocabulary
                (``CaptionTokenizer.vocabulary_size``).
            embed_dim: Dimensionality of each embedding vector
                (``model.embedding_dim``, default 512).
            max_len: Maximum sequence length (``model.max_length``, default 40).
            **kwargs: Optional base-layer keyword arguments (for example
                ``name``, ``dtype`` or ``trainable``), forwarded unchanged to
                ``tf.keras.layers.Layer.__init__``.
        """

        def __init__(
            self,
            vocab_size: int,
            embed_dim: int,
            max_len: int,
            **kwargs,
        ) -> None:
            # Forward base-layer options so callers can name or configure the
            # layer; with no kwargs this is the same call as before.
            super().__init__(**kwargs)

            # Kept only so get_config() can report the constructor arguments.
            self.vocab_size = vocab_size
            self.embed_dim = embed_dim
            self.max_len = max_len

            # One learned vector per vocabulary entry.
            self.token_embeddings = tf.keras.layers.Embedding(vocab_size, embed_dim)
            # One learned vector per position index in [0, max_len).
            # NOTE(review): ``input_shape`` is kept exactly as in the original;
            # newer Keras releases may warn about this argument -- confirm
            # before changing it.
            self.position_embeddings = tf.keras.layers.Embedding(
                max_len, embed_dim, input_shape=(None, max_len)
            )

        def call(self, input_ids):
            """Return token embeddings plus position embeddings for ``input_ids``.

            Args:
                input_ids: Integer token ids. Assumed to be shaped
                    ``(batch, seq_len)`` -- TODO confirm against callers.

            Returns:
                The element-wise sum of the token embeddings and the position
                embeddings.
            """
            # Size of the last axis of the input, read at run time.
            length = tf.shape(input_ids)[-1]
            # Position indices 0 .. length-1, given a leading axis of size 1.
            position_ids = tf.range(start=0, limit=length, delta=1)
            position_ids = tf.expand_dims(position_ids, axis=0)
            token_embeddings = self.token_embeddings(input_ids)
            position_embeddings = self.position_embeddings(position_ids)
            return token_embeddings + position_embeddings

        def get_config(self):
            """Return the base layer config extended with the constructor arguments."""
            config = super().get_config()
            config.update(
                {
                    "vocab_size": self.vocab_size,
                    "embed_dim": self.embed_dim,
                    "max_len": self.max_len,
                }
            )
            return config

    return Embeddings
# Build the layer class once, when this module is imported. The factory calls
# _import_tf(), so this statement imports TensorFlow at module import time.
Embeddings = _build_embeddings_class()