|
|
""" |
|
|
A collection of embedding models. A collection model includes |
|
|
the tokenizer(s), token embeddings and positional encodings |
|
|
(if necessary). |
|
|
""" |
|
|
|
|
|
import torch |
|
|
|
|
|
from models.components.positional_encoding import build_positional_encodings |
|
|
|
|
|
|
|
|
class EmbedderInterface(torch.nn.Module):
    """Interface for the embedder component of the model.

    Concrete embedders must implement `forward`, `tokenize_input`,
    `decode`, `pad_batch` and `truncate`. `get_sequence_info` ships with a
    default implementation that relies on `self.tokenizer` exposing
    `decode_batch`, `pad_token` and `eot_token` — TODO confirm all
    tokenizers used with this interface provide those attributes.
    """

    def __init__(self):
        super().__init__()
        # Placeholder sentinel; subclasses are expected to overwrite this
        # with the tokenizer's actual end-of-text token id.
        self.eot_token = ...

    def forward(self, token_ids: torch.LongTensor):
        """Take the token_ids as input and return the embeddings.

        Args:
            token_ids: torch.LongTensor of token ids.

        Returns:
            The embeddings corresponding to `token_ids`.
        """
        raise NotImplementedError

    def tokenize_input(self, input_string: str, truncate=False, add_eot=True):
        """Tokenize a single input string.

        Args:
            input_string: str - the text to tokenize.
            truncate: bool - whether to perform (left) truncation.
            add_eot: bool - whether to append the end-of-text token.

        Returns:
            typically token_ids of shape (S,)
        """
        raise NotImplementedError

    def decode(self, tokens: torch.LongTensor):
        """Decode a tensor of tokens into a string.

        For the default implementation of `get_sequence_info`, we assume
        that the tokens are of shape (B, S) and we decode each sequence
        in the batch.
        """
        raise NotImplementedError

    def pad_batch(self, token_lists, direction="right"):
        """Pad a list of token lists to the same length, and return the
        padded tensor and the mask tensor.

        Args:
            token_lists: list of lists of tokens.
            direction: str - which side to pad ("right" or "left").
        """
        raise NotImplementedError

    def truncate(self, token_lists):
        """Truncate a list of token lists to be shorter than the maximum
        length of the model, and return the truncated sequences.
        """
        raise NotImplementedError

    def get_sequence_info(self, x):
        """
        Given a batch of sequences of tokens, return the character
        lengths of the decoded sequences and a mask of "real"
        (non-pad, non-eot) positions.

        Args:
            x: torch.tensor(B, S)

        Returns:
            Tuple of (list of per-sequence decoded character lengths,
            boolean mask tensor of shape (B, S)).
        """
        # Decode each sequence in the batch and measure its character length.
        sequence_char_lengths = [
            len(seq) for seq in self.tokenizer.decode_batch(x)
        ]

        # True only at positions that are neither padding nor end-of-text.
        mask = (x != self.tokenizer.pad_token) & (x != self.tokenizer.eot_token)

        return (
            sequence_char_lengths,
            mask,
        )
|
|
|
|
|
|
|
|
class Embedder(EmbedderInterface):
    """
    A simple and flexible embedding model.

    All embedders should inherit from this class.
    """

    def __init__(self, model_cfg, tokenizer):
        """Build the token embedder and positional encodings.

        Args:
            model_cfg: mapping with at least "vocab_size", "hidden_dim"
                and "context_window" keys.
            tokenizer: tokenizer exposing `vocab_size`, `eot_token`,
                `encode`, `decode_batch` and `pad_batch`.

        Raises:
            ValueError: if the tokenizer's vocab size does not match the
                model config.
        """
        super().__init__()

        self.tokenizer = tokenizer
        # Validate with an explicit exception instead of `assert`:
        # asserts are stripped when Python runs with -O.
        if self.tokenizer.vocab_size != model_cfg["vocab_size"]:
            raise ValueError(
                f"{model_cfg['vocab_size']=} must match {self.tokenizer.vocab_size=}"
            )

        self.token_embedder = torch.nn.Embedding(
            num_embeddings=model_cfg["vocab_size"],
            embedding_dim=model_cfg["hidden_dim"],
        )

        self.positional_encodings = build_positional_encodings(model_cfg=model_cfg)
        self.eot_token = self.tokenizer.eot_token
        # NOTE(review): this mutates the shared tokenizer so code reading
        # `eos_token` sees the same id as `eot_token` — confirm no caller
        # depends on the tokenizer's original `eos_token`.
        self.tokenizer.eos_token = self.eot_token
        self.model_cfg = model_cfg

    def forward(self, token_ids):
        """
        Takes the token_ids as input and returns the embeddings.

        To obtain the token ids, use `.tokenize_input()`

        Args:
            token_ids: torch.tensor(B, S)

        Returns:
            embeddings: torch.tensor(B, S, H)
        """
        x = self.token_embedder(token_ids)

        x = self.positional_encodings(x)

        return x

    def tokenize_input(self, input_string, truncate=False, add_eot=True):
        """
        Tokenize an input string.

        Args:
            input_string: str - the text to tokenize.
            truncate: bool - whether to left-truncate to the context window.
            add_eot: bool - whether to append the end-of-text token.

        Returns:
            List of token ids. The EOT token (when added) survives
            truncation because truncation drops tokens from the left.
        """
        token_ids = self.tokenizer.encode(input_string)
        if add_eot:
            token_ids.append(self.eot_token)
        if truncate:
            token_ids = self.truncate([token_ids])[0]
        return token_ids

    def pad_batch(self, token_lists, direction="right"):
        """Pad a list of token lists to the same length,
        and return the padded tensor, and mask tensor.

        Args:
            token_lists: list of lists of tokens
            direction: str
        """
        return self.tokenizer.pad_batch(token_lists, direction=direction)

    def truncate(self, token_lists):
        """Left-truncate each token list to the model's context window.

        Args:
            token_lists: list of lists of tokens.

        Returns:
            List of token lists, each at most `context_window` long,
            keeping the most recent (rightmost) tokens.
        """
        max_length = self.model_cfg["context_window"]
        return [token_seq[-max_length:] for token_seq in token_lists]

    def decode(self, tokens):
        """
        Decode a tensor of tokens into a string.
        """
        return self.tokenizer.decode_batch(tokens)
|
|
|