from concurrent.futures import ProcessPoolExecutor
from contextlib import contextmanager
from functools import lru_cache, partial, wraps
import hashlib
import json
import logging
import math
from pathlib import Path
import typing as tp

import flashy
import flashy.distrib
import omegaconf
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
    """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences).
    For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]

    Args:
        lengths (torch.Tensor): tensor with lengths.
        max_len (int, optional): can set the max length manually. Defaults to None.
    Returns:
        torch.Tensor: mask with 0s where there are padding tokens, 1s elsewhere.
    """
    assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
    final_length = lengths.max().item() if not max_len else max_len
    final_length = max(final_length, 1)  # if all seqs are of len zero we don't want a zero-size tensor
    return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None]

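# Hedged usage sketch (illustration only, not part of the original module):
# reproduces the docstring example for ``length_to_mask``. The returned mask is
# boolean, True on valid positions and False on padding.
def _example_length_to_mask() -> None:
    lengths = torch.tensor([3, 5])
    mask = length_to_mask(lengths)
    # Expected:
    # tensor([[ True,  True,  True, False, False],
    #         [ True,  True,  True,  True,  True]])
    print(mask)
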
def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
    """Convenience function to map an omegaconf configuration to a dictionary.

    Args:
        cfg (omegaconf.DictConfig): Original configuration to map to dict.
    Returns:
        dict: Config as dictionary object.
    """
    dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
    assert isinstance(dct, dict)
    return dct

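# Hedged example (illustration only): builds a tiny DictConfig and converts it
# back to a plain dict with ``dict_from_config``. The config keys below are made
# up for the example and are not assumed to exist in any real configuration.
def _example_dict_from_config() -> None:
    cfg = omegaconf.OmegaConf.create({'model': {'dim': 512, 'norm': 'layer_norm'}})
    assert dict_from_config(cfg) == {'model': {'dim': 512, 'norm': 'layer_norm'}}
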
def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
    """Create normalization module for transformer encoder layer.

    Args:
        norm_type (str): Normalization method.
        dim (int): Dimension of the normalized layer.
        **kwargs (dict): Additional parameters for normalization layer.
    Returns:
        nn.Module: Normalization module.
    """
    if norm_type == 'layer_norm':
        return nn.LayerNorm(dim, eps=1e-5, **kwargs)
    else:
        raise ValueError(f"Unknown norm type: {norm_type}")

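# Minimal sketch (illustration only): 'layer_norm' is currently the only
# accepted value and yields a standard ``nn.LayerNorm`` over ``dim`` features.
def _example_create_norm_fn() -> None:
    norm = create_norm_fn('layer_norm', 512)
    assert isinstance(norm, nn.LayerNorm)
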
def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
    """LM layer initialization.
    Inspired from xlformers: https://github.com/fairinternal/xlformers

    Args:
        method (str): Method name for init function. Valid options are:
            'gaussian', 'uniform'.
        input_dim (int): Input dimension of the initialized module.
        init_depth (int, optional): Optional init depth value used to rescale
            the standard deviation if defined.
    """
    # Compute std
    std = 1 / math.sqrt(input_dim)
    # Rescale with depth
    if init_depth is not None:
        std = std / math.sqrt(2 * init_depth)

    if method == 'gaussian':
        return partial(
            torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
        )
    elif method == 'uniform':
        bound = math.sqrt(3) * std
        return partial(torch.nn.init.uniform_, a=-bound, b=bound)
    else:
        raise ValueError(f"Unsupported layer initialization method: {method}")

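# Hedged sketch (illustration only): ``get_init_fn`` returns a partial that
# initializes a weight tensor in place. The dimensions below are arbitrary.
def _example_get_init_fn() -> None:
    weight = torch.empty(256, 64)
    init_fn = get_init_fn('gaussian', input_dim=64, init_depth=12)
    init_fn(weight)  # fills ``weight`` with truncated normal samples
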
def init_layer(m: nn.Module,
               method: str,
               init_depth: tp.Optional[int] = None,
               zero_bias_init: bool = False):
    """Wrapper around ``get_init_fn`` for proper initialization of LM modules.

    Args:
        m (nn.Module): Module to initialize.
        method (str): Method name for the init function.
        init_depth (int, optional): Optional init depth value used to rescale
            the standard deviation if defined.
        zero_bias_init (bool): Whether to initialize the bias to 0 or not.
    """
    if isinstance(m, nn.Linear):
        init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
        if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
            # fp16 init is not supported on CPU: initialize in fp32, then cast back.
            weight = m.weight.float()
            init_fn(weight)
            m.weight.data[:] = weight.half()
        else:
            init_fn(m.weight)
        if zero_bias_init and m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Embedding):
        init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
        if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
            # Same fp16-on-CPU workaround as above.
            weight = m.weight.float()
            init_fn(weight)
            m.weight.data[:] = weight.half()
        else:
            init_fn(m.weight)

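# Hedged usage sketch (illustration only): ``init_layer`` is typically mapped
# over a model's submodules with ``nn.Module.apply`` and ``functools.partial``.
# The tiny model below is purely illustrative.
def _example_init_layer() -> None:
    model = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 64))
    model.apply(partial(init_layer, method='gaussian', init_depth=2, zero_bias_init=True))
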
def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """Get a list of tensors and collate them to a single tensor, according to the following logic:
    - `dim` specifies the time dimension which will be stacked and padded.
    - The output will contain 1 new dimension (dimension index 0) which will be the size
      of the original list.

    Args:
        tensors (tp.List[torch.Tensor]): List of tensors to collate.
        dim (int): Dimension which will be stacked and padded.
    Returns:
        tp.Tuple[torch.Tensor, torch.Tensor]:
            torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
                (dimension index 0) which will be the size of the original list.
            torch.Tensor: Tensor containing length of original tensor sizes (without padding).
    """
    tensors = [x.transpose(0, dim) for x in tensors]
    lens = torch.LongTensor([len(x) for x in tensors])
    padded_tensors = pad_sequence(tensors)
    padded_tensors = padded_tensors.transpose(0, 1)
    padded_tensors = padded_tensors.transpose(1, dim + 1)
    return padded_tensors, lens

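# Illustrative sketch (not part of the original module): pads two 1D tensors of
# different lengths along the time dimension and stacks them on a new batch dim.
def _example_collate() -> None:
    a, b = torch.ones(3), torch.ones(5)
    padded, lens = collate([a, b], dim=0)
    assert padded.shape == (2, 5)
    assert lens.tolist() == [3, 5]
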
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Sample next token from top K values along the last dimension of the input probs tensor.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        k (int): The k in "top-k".
    Returns:
        torch.Tensor: Sampled tokens.
    """
    top_k_value, _ = torch.topk(probs, k, dim=-1)
    min_value_top_k = top_k_value[..., [-1]]
    # Zero out everything below the k-th largest probability, then renormalize.
    probs *= (probs >= min_value_top_k).float()
    probs.div_(probs.sum(dim=-1, keepdim=True))
    next_token = multinomial(probs, num_samples=1)
    return next_token

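# Hedged sketch (illustration only): draws one token id from the top-2
# candidates of a toy distribution. Note that ``sample_top_k`` renormalizes
# ``probs`` in place.
def _example_sample_top_k() -> None:
    probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
    token = sample_top_k(probs, k=2)
    assert token.item() in (2, 3)  # only the two most likely ids can be drawn
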
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Sample next token from top P probabilities along the last dimension of the input probs tensor.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        p (float): The p in "top-p".
    Returns:
        torch.Tensor: Sampled tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Drop tokens whose cumulative probability (excluding themselves) exceeds p, then renormalize.
    mask = probs_sum - probs_sort > p
    probs_sort *= (~mask).float()
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = multinomial(probs_sort, num_samples=1)
    # Map back from sorted order to the original token indices.
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

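# Hedged sketch (illustration only): nucleus (top-p) sampling on the same toy
# distribution; with p=0.5 only the two most likely ids survive the cutoff.
def _example_sample_top_p() -> None:
    probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
    token = sample_top_p(probs, p=0.5)
    assert token.item() in (2, 3)
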
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
    """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.

    Args:
        input (torch.Tensor): The input tensor containing probabilities.
        num_samples (int): Number of samples to draw.
        replacement (bool): Whether to draw with replacement or not.
    Keyword args:
        generator (torch.Generator): A pseudorandom number generator for sampling.
    Returns:
        torch.Tensor: Last dimension contains num_samples indices
            sampled from the multinomial probability distribution
            located in the last dimension of tensor input.
    """
    # Flatten all leading dimensions, sample, then restore the original batch shape.
    input_ = input.reshape(-1, input.shape[-1])
    output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
    output = output_.reshape(*list(input.shape[:-1]), -1)
    return output