Spaces:
Runtime error
Runtime error
| import functools | |
| import torch | |
| from torch import nn | |
class WordAndPositionalEmbedding(nn.Module):
    r"""
    A :class:`~torch.nn.Module` for learned word embeddings and position
    embeddings for input tokens. Each token is mapped to a fixed dimensional
    word embedding; and corresponding positional embedding based on its index.
    These are summed together followed by layer normalization and an optional
    dropout.

    Parameters
    ----------
    vocab_size: int
        Size of token vocabulary.
    hidden_size: int
        Size of token embedding vectors.
    dropout: float, optional (default = 0.0)
        Dropout probability for final dropout applied after layer normalization.
    max_caption_length: int, optional (default = 30)
        Maximum length of input captions; this is used to create a fixed
        positional embedding lookup table.
    padding_idx: int, optional (default = 0)
        Token index of ``[PAD]`` token, word embedding for these tokens will
        be a vector of zeroes (and not trainable).
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        dropout: float = 0.0,
        max_caption_length: int = 30,
        padding_idx: int = 0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.padding_idx = padding_idx

        # Word embedding rows at ``padding_idx`` are zero and receive no
        # gradient (handled internally by ``nn.Embedding``).
        self.words = nn.Embedding(vocab_size, hidden_size, padding_idx=padding_idx)

        # We provide no "padding index" for positional embeddings. We zero out
        # the positional embeddings of padded positions as a post-processing
        # step in :meth:`forward`.
        self.positions = nn.Embedding(max_caption_length, hidden_size)
        self.layer_norm = nn.LayerNorm(
            hidden_size, eps=1e-8, elementwise_affine=True
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        r"""
        Get combined word and positional embeddings for input tokens.

        Parameters
        ----------
        tokens: torch.Tensor
            A tensor of shape ``(batch_size, max_caption_length)`` containing
            a batch of caption tokens, with values in ``[0, vocab_size)``.

        Returns
        -------
        torch.Tensor
            A tensor of shape ``(batch_size, max_caption_length, hidden_size)``
            containing corresponding token embeddings. Positions holding the
            padding token are all-zero vectors.
        """
        position_indices = self._create_position_indices(tokens)

        # shape: (batch_size, max_caption_length, hidden_size)
        word_embeddings = self.words(tokens)
        position_embeddings = self.positions(position_indices)

        # shape: (batch_size, max_caption_length, hidden_size)
        embeddings = self.layer_norm(word_embeddings + position_embeddings)
        embeddings = self.dropout(embeddings)

        # Zero-out embeddings for positions which have padding tokens.
        # shape: (batch_size, max_caption_length, 1)
        token_mask = (tokens != self.padding_idx).unsqueeze(-1)

        # shape: (batch_size, max_caption_length, hidden_size)
        embeddings = embeddings * token_mask.type(embeddings.dtype)
        return embeddings

    def _create_position_indices(self, tokens: torch.Tensor):
        r"""
        Build per-position indices ``[0, 1, ..., L-1]`` broadcast to the
        batch shape of ``tokens``; used to look up positional embeddings.
        """
        batch_size, max_caption_length = tokens.size()
        positions = torch.arange(
            max_caption_length, dtype=tokens.dtype, device=tokens.device
        )
        # shape: (batch_size, max_caption_length)
        positions = positions.unsqueeze(0).expand(batch_size, max_caption_length)
        return positions