File size: 631 Bytes
bfc4d3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch


class DualCodebookEmbedding(torch.nn.Module):
    """Embed token pairs from two codebooks with a single shared table.

    Each timestep carries two token ids (one per codebook). Both ids are
    looked up in the same ``nn.Embedding`` of width ``input_size // 2``,
    and the two halves are concatenated so the output width equals
    ``input_size``.

    Args:
        vocab_size (int): number of entries in the shared embedding table.
        input_size (int): total output embedding width; must be even,
            since each of the two codebook embeddings contributes half.

    Raises:
        ValueError: if ``input_size`` is odd — otherwise the concatenated
            output would silently have width ``2 * (input_size // 2)``
            instead of ``input_size``.
    """

    def __init__(self,
                 vocab_size: int,
                 input_size: int,
                 ):
        super().__init__()
        if input_size % 2 != 0:
            raise ValueError(
                f"input_size must be even, got {input_size}: each of the "
                "two codebooks embeds into input_size // 2 dimensions"
            )
        # Shared table: both codebooks index the same weights.
        self.embedding = torch.nn.Embedding(vocab_size, input_size // 2)

    def forward(self, token: torch.Tensor) -> torch.Tensor:
        """Embed both codebook ids and concatenate along the feature dim.

        Args:
            token (torch.Tensor): integer tensor of shape (b, t, 2);
                ``token[..., 0]`` and ``token[..., 1]`` are the ids for
                the first and second codebook respectively.

        Returns:
            torch.Tensor: shape (b, t, input_size) — first-codebook
            embedding in the first half of the last dim, second-codebook
            embedding in the second half.
        """
        embed1 = self.embedding(token[..., 0])
        embed2 = self.embedding(token[..., 1])
        return torch.cat([embed1, embed2], dim=-1)