tinyllm_canonical / models /embedding_models.py
stegsoph's picture
Upload folder using huggingface_hub
5d2c747 verified
"""
A collection of embedding models. A collection model includes
the tokenizer(s), token embeddings and positional encodings
(if necessary).
"""
import torch
from models.components.positional_encoding import build_positional_encodings
class EmbedderInterface(torch.nn.Module):
"""Interface for the embedder component of the model."""
def __init__(self):
super().__init__()
self.eot_token = ...
def forward(self, token_ids: torch.LongTensor):
"""This function should take the token_ids as input,
and return the embeddings."""
raise NotImplementedError
def tokenize_input(self, input_string: str, truncate=False, add_eot=True):
"""This function should take a single input string and returns
the tokenized input.
Args:
input_string: str
truncate: bool - whether to perform (left) truncation
add_eot: bool
Returns:
typically token_ids of shape (S,)
"""
raise NotImplementedError
def decode(self, tokens: torch.LongTensor):
"""This function should decode a tensor of tokens into a string.
For the default implementation of get_sequence_info,
we assume that the tokens are of shape (B, S) and we
decode each sequence in the batch."""
raise NotImplementedError
def pad_batch(self, token_lists, direction="right"):
"""Pad a list of token lists to the same length,
and return the padded tensor, and mask tensor."""
raise NotImplementedError
def truncate(self, token_lists):
"""Truncate a list of token lists, to be shorter than the,
maximum length of the model and return the truncated tensor.
"""
raise NotImplementedError
def get_sequence_info(self, x):
"""
Given a batch of sequences of tokens, return
the character lengths.
Args:
x: torch.tensor(B, S)
"""
sequence_char_lengths = []
# then we decode everything
# batch decode
sequences = self.tokenizer.decode_batch(x)
for seq in sequences:
sequence_char_lengths.append(len(seq))
# obtain the mask for end-of-word and pad tokens
mask = x != self.tokenizer.pad_token
mask = mask & (x != self.tokenizer.eot_token)
return (
sequence_char_lengths,
mask,
)
class Embedder(EmbedderInterface):
"""
A simple and flexible embedding model.
All embedders should inherit from this class.
"""
def __init__(self, model_cfg, tokenizer):
super().__init__()
# build the tokenizer
self.tokenizer = tokenizer
assert self.tokenizer.vocab_size == model_cfg["vocab_size"], f"{model_cfg['vocab_size']=} must match {self.tokenizer.vocab_size=}"
# build the token embeddings
self.token_embedder = torch.nn.Embedding(
num_embeddings=model_cfg["vocab_size"],
embedding_dim=model_cfg["hidden_dim"],
)
# build the positional encodings
self.positional_encodings = build_positional_encodings(model_cfg=model_cfg)
self.eot_token = self.tokenizer.eot_token
self.tokenizer.eos_token = self.eot_token
self.model_cfg = model_cfg
def forward(self, token_ids):
"""
Takes the token_ids as input
and returns the embeddings.
To obtain the token ids, use `.tokenize_input()`
Args:
token_ids: torch.tensor(B, S)
Returns:
embeddings: torch.tensor(B, S, H)
"""
# get the token embeddings
x = self.token_embedder(token_ids)
# apply the positional encoding, if any
x = self.positional_encodings(x)
return x
def tokenize_input(self, input_string, truncate=False, add_eot=True):
"""
Tokenize an input string.
"""
token_ids = self.tokenizer.encode(input_string)
if add_eot:
token_ids.append(self.eot_token)
if truncate:
token_ids = self.truncate([token_ids])[0]
return token_ids
def pad_batch(self, token_lists, direction="right"):
"""Pad a list of token lists to the same length,
and return the padded tensor, and mask tensor.
Args:
token_lists: list of lists of tokens
direction: str
"""
return self.tokenizer.pad_batch(token_lists, direction=direction)
def truncate(self, token_lists):
# get model max length
max_length = self.model_cfg["context_window"]
return [token_seq[-max_length:] for token_seq in token_lists]
def decode(self, tokens):
"""
Decode a tensor of tokens into a string.
"""
return self.tokenizer.decode_batch(tokens)