File size: 4,878 Bytes
5d2c747
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""
A collection of embedding models. A collection model includes
the tokenizer(s), token embeddings and positional encodings
(if necessary).
"""

import torch

from models.components.positional_encoding import build_positional_encodings


class EmbedderInterface(torch.nn.Module):
    """Contract for the embedder component of the model.

    Concrete embedders supply tokenization, decoding, padding and
    truncation helpers, plus a forward pass mapping token ids to
    embeddings. Only ``get_sequence_info`` has a default implementation.
    """

    def __init__(self):
        super().__init__()
        # Placeholder; concrete subclasses assign the real end-of-text id.
        self.eot_token = ...

    def forward(self, token_ids: torch.LongTensor):
        """Map ``token_ids`` to their embeddings.

        Args:
            token_ids: tensor of token ids.
        Returns:
            the corresponding embeddings.
        """
        raise NotImplementedError

    def tokenize_input(self, input_string: str, truncate=False, add_eot=True):
        """Tokenize a single input string.

        Args:
            input_string: str
            truncate: bool - whether to perform (left) truncation
            add_eot: bool - whether to append the end-of-text token
        Returns:
            typically token_ids of shape (S,)
        """
        raise NotImplementedError

    def decode(self, tokens: torch.LongTensor):
        """Decode a tensor of tokens back into text.

        The default ``get_sequence_info`` implementation assumes the
        tokens have shape (B, S) and decodes each sequence in the batch.
        """
        raise NotImplementedError

    def pad_batch(self, token_lists, direction="right"):
        """Pad token lists to a common length; return (padded, mask) tensors."""
        raise NotImplementedError

    def truncate(self, token_lists):
        """Shorten each token list so it fits the model's maximum length."""
        raise NotImplementedError

    def get_sequence_info(self, x):
        """Return per-sequence character lengths and a content mask.

        Args:
            x: torch.tensor(B, S)
        Returns:
            tuple of (list of decoded character lengths, bool mask that is
            False at pad and end-of-text positions).
        """
        # Batch-decode all sequences, then measure each decoded string.
        decoded = self.tokenizer.decode_batch(x)
        sequence_char_lengths = [len(seq) for seq in decoded]

        # True wherever the token is neither padding nor end-of-text.
        mask = (x != self.tokenizer.pad_token) & (x != self.tokenizer.eot_token)

        return (
            sequence_char_lengths,
            mask,
        )


class Embedder(EmbedderInterface):
    """
    A simple and flexible embedding model.

    All embedders should inherit from this class.
    """

    def __init__(self, model_cfg, tokenizer):
        super().__init__()
        self.model_cfg = model_cfg

        # Tokenizer is supplied by the caller; its vocabulary must agree
        # with the configured vocab size before we size the embedding table.
        self.tokenizer = tokenizer
        assert self.tokenizer.vocab_size == model_cfg["vocab_size"], f"{model_cfg['vocab_size']=} must match {self.tokenizer.vocab_size=}"

        # Learnable token-id -> vector lookup table.
        self.token_embedder = torch.nn.Embedding(
            num_embeddings=model_cfg["vocab_size"],
            embedding_dim=model_cfg["hidden_dim"],
        )

        # Positional encoding module built from the config (may be a no-op).
        self.positional_encodings = build_positional_encodings(model_cfg=model_cfg)

        # Mirror the end-of-text id onto the tokenizer's eos_token attribute
        # (NOTE(review): presumably for downstream consumers expecting eos —
        # confirm against callers).
        self.eot_token = self.tokenizer.eot_token
        self.tokenizer.eos_token = self.eot_token

    def forward(self, token_ids):
        """
        Takes the token_ids as input
        and returns the embeddings.

        To obtain the token ids, use `.tokenize_input()`
        Args:
            token_ids: torch.tensor(B, S)
        Returns:
            embeddings: torch.tensor(B, S, H)
        """
        # Embed the ids, then apply the positional encoding (if any).
        embedded = self.token_embedder(token_ids)
        return self.positional_encodings(embedded)

    def tokenize_input(self, input_string, truncate=False, add_eot=True):
        """
        Tokenize a single input string, optionally appending the
        end-of-text token and left-truncating to the context window.
        """
        token_ids = self.tokenizer.encode(input_string)
        if add_eot:
            token_ids = token_ids + [self.eot_token]
        if not truncate:
            return token_ids
        return self.truncate([token_ids])[0]

    def pad_batch(self, token_lists, direction="right"):
        """Pad a list of token lists to the same length,
        and return the padded tensor, and mask tensor.
        Args:
            token_lists: list of lists of tokens
            direction: str - which side to pad on
        """
        return self.tokenizer.pad_batch(token_lists, direction=direction)

    def truncate(self, token_lists):
        """Left-truncate every sequence to the model's context window."""
        limit = self.model_cfg["context_window"]
        truncated = []
        for sequence in token_lists:
            # Keep the most recent `limit` tokens (left truncation).
            truncated.append(sequence[-limit:])
        return truncated

    def decode(self, tokens):
        """
        Decode a (B, S) tensor of tokens; returns one string per sequence.
        """
        return self.tokenizer.decode_batch(tokens)