| """This file contains the definition of utility functions to group tokens.""" | |
| import math | |
| import torch | |


def combine_factorized_tokens(
    tokens: torch.Tensor, codebook_size: int, splits: int
) -> torch.Tensor:
    """
    Combine factorized sub-tokens back into single tokens.

    Args:
        tokens -> torch.Tensor: Integer tensor of shape (batch_size, n, splits).
        codebook_size -> int: The size of the codebook; must be a power of two
            whose exponent is divisible by `splits`.
        splits -> int: Number of splits.

    Returns:
        combined_tokens -> torch.Tensor: Integer tensor of shape (batch_size, n).
    """
    # Match the input's integer dtype; torch.zeros alone would default to
    # float32 and silently promote the result.
    combined_tokens = torch.zeros(
        (tokens.shape[0], tokens.shape[1]), dtype=tokens.dtype, device=tokens.device
    )
    # Each split carries log2(codebook_size) / splits bits of the original token.
    bit_shift = int(math.log2(codebook_size)) // splits
    for i in range(splits):
        combined_tokens += tokens[..., i] << (i * bit_shift)
    return combined_tokens
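

# Worked example (illustrative values, not from the original file): with
# codebook_size=1024 and splits=2, bit_shift = 10 // 2 = 5 and the mask is 31.
# The 10-bit token 35 (0b0000100011) splits into (35 & 31, 35 >> 5) == (3, 1),
# and combining packs the pieces back: 3 + (1 << 5) == 35.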


def split_factorized_tokens(
    tokens: torch.Tensor, codebook_size: int, splits: int
) -> torch.Tensor:
    """
    Split each token into `splits` factorized sub-tokens.

    Args:
        tokens -> torch.Tensor: Integer tensor of shape (batch_size, n).
        codebook_size -> int: The size of the codebook; must be a power of two
            whose exponent is divisible by `splits`.
        splits -> int: Number of splits.

    Returns:
        split_tokens -> torch.Tensor: Integer tensor of shape
            (batch_size, n, splits).
    """
    # Each sub-token keeps log2(codebook_size) / splits bits of the original token.
    bit_shift = int(math.log2(codebook_size)) // splits
    bit_mask = (1 << bit_shift) - 1
    split_tokens = []
    for i in range(splits):
        # Shift the i-th bit group down to the low bits, then mask it out.
        split_tokens.append((tokens >> (i * bit_shift)) & bit_mask)
    return torch.stack(split_tokens, dim=2)
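

# Round-trip invariant exercised by the checks below: for any integer tensor x
# with values in [0, codebook_size) and any `splits` that divides
# log2(codebook_size),
#   combine_factorized_tokens(split_factorized_tokens(x, c, s), c, s) == x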


if __name__ == "__main__":
    # randint's upper bound is exclusive, so use 1024 to cover the full codebook.
    tokens = torch.randint(0, 1024, (1, 16))

    # splits=1 is an identity round trip.
    split_tokens = split_factorized_tokens(tokens, 1024, 1)
    assert split_tokens.shape == (1, 16, 1)
    assert split_tokens.dtype == torch.int64
    combined_tokens = combine_factorized_tokens(split_tokens, 1024, 1)
    assert (tokens == combined_tokens).all()

    # splits=2 factorizes each 10-bit token into two 5-bit sub-tokens.
    split_tokens = split_factorized_tokens(tokens, 1024, 2)
    combined_tokens = combine_factorized_tokens(split_tokens, 1024, 2)
    assert split_tokens.shape == (1, 16, 2)
    assert (tokens == combined_tokens).all(), f"{tokens} != {combined_tokens}"
    # Index 0 holds the low 5 bits (mask 0b11111 == 31), index 1 the high 5 bits.
    assert (torch.bitwise_right_shift(tokens, 5) == split_tokens[..., 1]).all()
    assert ((tokens & 31) == split_tokens[..., 0]).all()
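
    # Illustrative extra check (an assumed extension, not part of the original
    # tests): any splits value dividing log2(codebook_size) should round-trip;
    # here, five 2-bit sub-tokens per 10-bit token.
    split_tokens = split_factorized_tokens(tokens, 1024, 5)
    combined_tokens = combine_factorized_tokens(split_tokens, 1024, 5)
    assert split_tokens.shape == (1, 16, 5)
    assert (tokens == combined_tokens).all()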