import typing as tp

import torch
import torch.nn as nn


def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
    """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences).
    For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]

    Args:
        lengths (torch.Tensor): tensor with lengths.
        max_len (int): can set the max length manually. Defaults to None.
    Returns:
        torch.Tensor: mask with 0s where there are pad tokens, else 1s.
    """
    assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
    final_length = lengths.max().item() if not max_len else max_len
    final_length = max(final_length, 1)  # if all seqs are of len zero we don't want a zero-size tensor
    # build the range directly on the right device to avoid an extra copy
    return torch.arange(final_length, device=lengths.device)[None, :] < lengths[:, None]
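

# Illustrative usage sketch (added; not part of the original module):
# lengths [3, 5] produce a [2, 5] boolean mask, True over real tokens and
# False over padding, matching the docstring example above.
def _demo_length_to_mask() -> None:
    mask = length_to_mask(torch.tensor([3, 5]))
    assert mask.tolist() == [[True, True, True, False, False],
                             [True, True, True, True, True]]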


def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Create sinusoidal positional embedding, with shape `[B, T, C]`.

    Args:
        positions (torch.Tensor): LongTensor of positions, broadcastable to `[B, T, 1]`.
        dim (int): Dimension of the embedding (must be even).
        max_period (float): Maximum period of the cosine/sine functions.
        dtype (torch.dtype or str): dtype to use to generate the embedding.
    Returns:
        torch.Tensor: Sinusoidal positional embedding.
    """
    # We aim for BTC format
    assert dim % 2 == 0
    half_dim = dim // 2
    positions = positions.to(dtype)
    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)  # avoid sync point
    # note: the `half_dim - 1` exponent denominator assumes dim >= 4; dim == 2 would divide by zero
    phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
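

# Illustrative usage sketch (added; not part of the original module): positions
# are passed as `[B, T, 1]` so they broadcast against the `[1, 1, half_dim]`
# exponent grid built inside `create_sin_embedding`, giving a `[B, T, dim]` output.
def _demo_create_sin_embedding() -> None:
    positions = torch.arange(10).view(1, -1, 1)  # [B=1, T=10, 1]
    emb = create_sin_embedding(positions, dim=64)
    assert emb.shape == (1, 10, 64)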


def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
    """Create normalization module for transformer encoder layer.

    Args:
        norm_type (str): Normalization method (only 'layer_norm' is currently supported).
        dim (int): Dimension of the normalized layer.
        **kwargs (dict): Additional parameters for normalization layer.
    Returns:
        nn.Module: Normalization module.
    """
    if norm_type == 'layer_norm':
        return nn.LayerNorm(dim, eps=1e-5, **kwargs)
    else:
        raise ValueError(f"Unknown norm type: {norm_type}")
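

# Illustrative usage sketch (added; not part of the original module): the
# factory currently knows a single norm type; anything else raises ValueError.
def _demo_create_norm_fn() -> None:
    norm = create_norm_fn('layer_norm', 512)
    assert isinstance(norm, nn.LayerNorm)
    out = norm(torch.randn(2, 8, 512))  # normalizes over the last dimension
    assert out.shape == (2, 8, 512)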