"""
A collection of embedding models. A collection model includes
the tokenizer(s), token embeddings and positional encodings
(if necessary).
"""
import torch
from models.components.positional_encoding import build_positional_encodings
class EmbedderInterface(torch.nn.Module):
    """Interface for the embedder component of the model.

    Subclasses bundle a tokenizer, token embeddings and (optional)
    positional encodings behind a single module.
    """

    def __init__(self):
        super().__init__()
        # Placeholder only — subclasses must overwrite this with the
        # tokenizer's end-of-text token id.
        self.eot_token = ...

    def forward(self, token_ids: torch.LongTensor):
        """Take the token_ids as input and return the embeddings.

        Args:
            token_ids: typically torch.tensor(B, S)
        Returns:
            embeddings: typically torch.tensor(B, S, H)
        """
        raise NotImplementedError

    def tokenize_input(self, input_string: str, truncate: bool = False, add_eot: bool = True):
        """Tokenize a single input string.

        Args:
            input_string: str
            truncate: bool - whether to perform (left) truncation
            add_eot: bool - whether to append the end-of-text token
        Returns:
            typically token_ids of shape (S,)
        """
        raise NotImplementedError

    def decode(self, tokens: torch.LongTensor):
        """Decode a tensor of tokens into strings.

        For the default implementation of get_sequence_info we assume
        that the tokens are of shape (B, S) and each sequence in the
        batch is decoded (i.e. one string per sequence).
        """
        raise NotImplementedError

    def pad_batch(self, token_lists, direction="right"):
        """Pad a list of token lists to the same length, and return the
        padded tensor and the mask tensor."""
        raise NotImplementedError

    def truncate(self, token_lists):
        """Truncate a list of token lists to be shorter than the maximum
        length of the model, and return the truncated sequences."""
        raise NotImplementedError

    def get_sequence_info(self, x):
        """Return per-sequence character lengths and a validity mask.

        Args:
            x: torch.tensor(B, S) of token ids
        Returns:
            tuple of:
                - list[int]: decoded character length of each sequence
                - torch.BoolTensor(B, S): True everywhere except at pad
                  and end-of-text token positions

        NOTE(review): relies on ``self.tokenizer`` exposing
        ``decode_batch``, ``pad_token`` and ``eot_token`` — the interface
        itself never sets ``self.tokenizer``; subclasses (e.g. Embedder)
        must provide it.
        """
        # Batch-decode every sequence at once, then measure each string.
        sequences = self.tokenizer.decode_batch(x)
        sequence_char_lengths = [len(seq) for seq in sequences]
        # Mask out pad and end-of-text tokens.
        mask = (x != self.tokenizer.pad_token) & (x != self.tokenizer.eot_token)
        return (
            sequence_char_lengths,
            mask,
        )
class Embedder(EmbedderInterface):
    """
    A simple and flexible embedding model.

    Combines a tokenizer, learned token embeddings and positional
    encodings. All embedders should inherit from this class.
    """

    def __init__(self, model_cfg, tokenizer):
        """
        Args:
            model_cfg: dict-like config providing at least
                ``vocab_size``, ``hidden_dim`` and ``context_window``.
            tokenizer: tokenizer exposing ``vocab_size``, ``eot_token``,
                ``encode``, ``decode_batch`` and ``pad_batch``.
        """
        super().__init__()
        # build the tokenizer
        self.tokenizer = tokenizer
        # The embedding table and the tokenizer must agree on vocab size.
        assert self.tokenizer.vocab_size == model_cfg["vocab_size"], (
            f"{model_cfg['vocab_size']=} must match {self.tokenizer.vocab_size=}"
        )
        # build the token embeddings
        self.token_embedder = torch.nn.Embedding(
            num_embeddings=model_cfg["vocab_size"],
            embedding_dim=model_cfg["hidden_dim"],
        )
        # build the positional encodings (project helper; may be a no-op
        # depending on the config)
        self.positional_encodings = build_positional_encodings(model_cfg=model_cfg)
        self.eot_token = self.tokenizer.eot_token
        # keep the tokenizer's eos alias in sync with the eot token
        self.tokenizer.eos_token = self.eot_token
        self.model_cfg = model_cfg

    def forward(self, token_ids):
        """
        Take the token_ids as input and return the embeddings.
        To obtain the token ids, use `.tokenize_input()`.

        Args:
            token_ids: torch.tensor(B, S)
        Returns:
            embeddings: torch.tensor(B, S, H)
        """
        # get the token embeddings
        embeddings = self.token_embedder(token_ids)
        # apply the positional encoding, if any
        embeddings = self.positional_encodings(embeddings)
        return embeddings

    def tokenize_input(self, input_string, truncate=False, add_eot=True):
        """
        Tokenize an input string.

        Args:
            input_string: str
            truncate: bool - left-truncate to the context window
            add_eot: bool - append the end-of-text token before truncating
        Returns:
            list of token ids
        """
        token_ids = self.tokenizer.encode(input_string)
        if add_eot:
            token_ids.append(self.eot_token)
        if truncate:
            # left truncation, so an appended eot token always survives
            token_ids = self.truncate([token_ids])[0]
        return token_ids

    def pad_batch(self, token_lists, direction="right"):
        """Pad a list of token lists to the same length,
        and return the padded tensor, and mask tensor.

        Args:
            token_lists: list of lists of tokens
            direction: str - "right" or "left" padding
        """
        return self.tokenizer.pad_batch(token_lists, direction=direction)

    def truncate(self, token_lists):
        """Left-truncate each token list to the model's context window,
        keeping the most recent (rightmost) tokens."""
        max_length = self.model_cfg["context_window"]
        return [token_seq[-max_length:] for token_seq in token_lists]

    def decode(self, tokens):
        """
        Decode a (B, S) tensor of tokens into a list of strings,
        one per sequence in the batch.
        """
        return self.tokenizer.decode_batch(tokens)