import math
import warnings
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from fla.modules.activations import ACT2FN
from fla.utils import checkpoint

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn = None
    causal_conv1d_update = None


def fft_conv(u, k, dropout_mask, gelu=True, k_rev=None):
    # Causal convolution of `u` ([..., hidden, seq_len]) with filter `k` ([hidden, seq_len]),
    # computed in the frequency domain (zero-padded to `2 * seq_len`), plus a residual connection.
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen
    k_f = torch.fft.rfft(k, n=fft_size) / fft_size
    if k_rev is not None:
        k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
        k_f = k_f + k_rev_f.conj()
    u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)

    if len(u.shape) > 3:
        k_f = k_f.unsqueeze(1)
    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]

    out = y + u
    if gelu:
        out = F.gelu(out)
    if dropout_mask is not None:
        return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype)
    else:
        return out.to(dtype=u.dtype)
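# The sketch below is not part of the original module; it is a minimal check of what
# `fft_conv` computes, with illustrative shapes: zero-padding to `2 * seq_len` makes the
# FFT product equal to a depthwise causal convolution, followed by the residual `u`.
def _fft_conv_equivalence_sketch():
    u = torch.randn(1, 4, 32)   # [batch, hidden, seq_len]
    k = torch.randn(4, 32)      # [hidden, seq_len]
    y_fft = fft_conv(u, k, dropout_mask=None, gelu=False)
    # Reference: depthwise causal convolution (the flipped kernel turns PyTorch's
    # cross-correlation into a convolution), plus the residual connection.
    y_ref = F.conv1d(F.pad(u, (31, 0)), k.flip(-1).unsqueeze(1), groups=4) + u
    assert torch.allclose(y_fft, y_ref, atol=1e-3)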
@checkpoint
def proj_then_conv1d(
    x: torch.Tensor,
    proj_weight: torch.Tensor,
    conv1d_weight: torch.Tensor,
    conv1d_bias: Optional[torch.Tensor] = None,
    cache: Optional[torch.Tensor] = None
) -> torch.Tensor:
    # Fuse the input projection with the `[b, t, d] -> [b, d, t]` transpose expected by the conv kernel.
    x = rearrange(proj_weight @ rearrange(x, "b t d -> d (b t)"), "d (b t) -> b d t", t=x.shape[-2])

    if causal_conv1d_fn is None:
        raise ImportError("`causal_conv1d_fn` is not available. Please install `causal-conv1d` first.")
    if cache is None:
        x = causal_conv1d_fn(
            x=x,
            weight=rearrange(conv1d_weight, "d 1 w -> d w"),
            bias=conv1d_bias,
            activation="silu",
        ).transpose(1, 2)
    else:
        assert x.shape[-1] == 1, "Only support decoding with 1 token at a time for now"
        x = x.squeeze(-1)
        x = causal_conv1d_update(
            x=x,
            conv_state=cache,
            weight=rearrange(conv1d_weight, "d 1 w -> d w"),
            bias=conv1d_bias,
            activation="silu",
        )
    return x
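# Hedged usage sketch (not part of the original module): `proj_then_conv1d` requires the
# optional `causal-conv1d` CUDA kernels, so the tensors below live on a CUDA device.
# All sizes and names here are illustrative assumptions.
def _proj_then_conv1d_usage_sketch():
    batch_size, seq_len, hidden_size, conv_size = 2, 16, 64, 4
    x = torch.randn(batch_size, seq_len, hidden_size, device='cuda')
    proj_weight = torch.randn(hidden_size, hidden_size, device='cuda')
    conv1d_weight = torch.randn(hidden_size, 1, conv_size, device='cuda')
    # Projection followed by a depthwise causal convolution with SiLU activation.
    return proj_then_conv1d(x, proj_weight, conv1d_weight)  # [batch_size, seq_len, hidden_size]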
class ShortConvolution(nn.Conv1d):
    """
    Simple wrapper around `nn.Conv1d` that accepts inputs with the channel dimension last.
    """

    def __init__(
        self,
        hidden_size: int,
        kernel_size: int,
        bias: bool = False,
        activation: Optional[str] = 'silu',
        use_fast_conv1d: Optional[bool] = True
    ):
        super().__init__(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=kernel_size,
            groups=hidden_size,
            bias=bias,
            padding=kernel_size - 1
        )

        self.hidden_size = hidden_size
        self.activation = None
        if activation is not None:
            assert activation in ['silu', 'swish'], f"Activation `{activation}` not supported yet."
            self.activation = activation

        if causal_conv1d_fn is None:
            if use_fast_conv1d:
                raise RuntimeError(
                    "Please either install `causal-conv1d>=1.4.0` to enable the fast causal short convolution CUDA kernel "
                    "or set `use_fast_conv1d` to False"
                )
            else:
                warnings.warn(
                    "The naive PyTorch version is very slow in practice, "
                    "please run `pip install causal-conv1d>=1.4.0` to install the fast causal short convolution CUDA kernel",
                    category=ImportWarning
                )
        self.use_fast_conv1d = use_fast_conv1d

    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.padding_mode != 'zeros':
            s += ', padding_mode={padding_mode}'
        if self.activation is not None:
            s += ', activation={activation}'
        if not self.use_fast_conv1d:
            s += ', use_fast_conv1d={use_fast_conv1d}'
        return s.format(**self.__dict__)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        cache: Optional[torch.Tensor] = None,
        output_final_state: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x (`torch.Tensor`):
                Tensor of shape `[batch_size, seq_len, hidden_size]`.
            mask (`Optional[torch.Tensor]`):
                Attention mask of shape `[batch_size, seq_len]` used to zero out padded positions.
            cache (`Optional[torch.Tensor]`):
                Previous cache tensor of shape `[batch_size, hidden_size, kernel_size]`.
                If provided, the cache is updated **inplace**.
            output_final_state (`Optional[bool]`):
                Whether to output the final state of shape `[batch_size, hidden_size, kernel_size]`. Default: `False`.
        Returns:
            A tuple of the output tensor of shape `[batch_size, seq_len, hidden_size]`
            and the cache tensor (`None` if neither `cache` nor `output_final_state` is given).
        """

        batch_size, _, hidden_size = x.shape
        if mask is not None:
            x = x.mul_(mask.unsqueeze(-1))
        if output_final_state and cache is None:
            cache = x.new_zeros(batch_size, hidden_size, self.kernel_size[0])
        if cache is not None and x.shape[1] == 1:
            return self.step(x, cache)
        x = rearrange(x, "b t d -> b d t")
        # Update the cache with the last `kernel_size` positions (left-padded if the sequence is shorter).
        if cache is not None:
            cache.copy_(F.pad(x, (self.kernel_size[0] - x.shape[-1], 0)))
        if self.use_fast_conv1d:
            x = causal_conv1d_fn(
                x=x,
                weight=rearrange(self.weight, "d 1 w -> d w"),
                bias=self.bias,
                activation=self.activation,
            )
        else:
            x = self._conv_forward(x, self.weight, self.bias)[..., :x.shape[-1]]
            if self.activation is not None:
                x = ACT2FN[self.activation](x)
        return rearrange(x, "b d t -> b t d"), cache

    def step(
        self,
        x: torch.Tensor,
        cache: torch.Tensor
    ):
        assert x.shape[1] == 1, "Only support decoding with 1 token at a time for now"

        x = x.squeeze(1)
        if self.use_fast_conv1d:
            x = causal_conv1d_update(
                x=x,
                conv_state=cache,
                weight=rearrange(self.weight, "d 1 w -> d w"),
                bias=self.bias,
                activation=self.activation,
            )
        else:
            dtype = x.dtype
            # Shift the cache left by one position and append the new token.
            cache.copy_(torch.roll(cache, shifts=-1, dims=-1))
            cache[:, :, -1] = x
            x = torch.sum(cache * rearrange(self.weight, "d 1 w -> d w"), dim=-1)
            if self.bias is not None:
                x = x + self.bias
            if self.activation is not None:
                x = ACT2FN[self.activation](x).to(dtype=dtype)
        return x.unsqueeze(1), cache

    @property
    def state_size(self) -> int:
        return self.hidden_size * self.kernel_size[0]
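# Hedged usage sketch (not part of the original module): prefill a prompt with
# `ShortConvolution`, keep the final convolution state, then decode one token at a time.
# Sizes are illustrative assumptions; the naive PyTorch path is used so that the sketch
# does not require the optional `causal-conv1d` package.
def _short_convolution_usage_sketch():
    conv = ShortConvolution(hidden_size=64, kernel_size=4, use_fast_conv1d=False)
    x = torch.randn(2, 16, 64)
    # Prefill: process the whole prompt and materialize the final state as the cache.
    y, cache = conv(x, output_final_state=True)
    # Decode: feed a single new token, updating the cache inplace.
    x_t = torch.randn(2, 1, 64)
    y_t, cache = conv(x_t, cache=cache)
    return y, y_t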
class LongConvolution(nn.Module):
    """
    LongConvolution applies a convolution operation on the input tensor using a fixed
    filter of length `max_len`.
    The filter is learned during training and is applied using FFT convolution.
    Args:
        hidden_size (int): The number of expected features in the input and output.
        max_len (int): The maximum sequence length.
    Returns:
        y: [batch_size, seq_len, hidden_size] tensor
    """

    def __init__(
        self,
        hidden_size: int,
        max_len: int,
        **kwargs,
    ):
        """
        Initializes the LongConvolution module.
        Args:
            hidden_size (int): The number of expected features in the input and output.
            max_len (int): The maximum sequence length.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.filter = nn.Parameter(torch.randn(self.hidden_size, max_len), requires_grad=True)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """
        Applies the LongConvolution operation on the input tensor.
        Args:
            x: [batch_size, seq_len, hidden_size] tensor
        Returns:
            y: [batch_size, seq_len, hidden_size] tensor
        """
        x = x.transpose(1, 2)
        y = fft_conv(x, self.filter, dropout_mask=None, gelu=False)
        y = y.transpose(1, 2)
        return y.to(dtype=x.dtype)
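# Hedged usage sketch (not part of the original module): a learned length-`max_len` filter
# applied with FFT convolution; sizes are illustrative assumptions.
def _long_convolution_usage_sketch():
    conv = LongConvolution(hidden_size=64, max_len=128)
    x = torch.randn(2, 128, 64)
    return conv(x)  # [2, 128, 64]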
class PositionalEmbedding(nn.Module):
    def __init__(self, emb_dim: int, seq_len: int, **kwargs):
        """Complex exponential positional embeddings for implicit long convolution filters."""
        super().__init__()

        self.seq_len = seq_len
        # Normalized time coordinate in [0, 1], shape [1, L, 1].
        t = torch.linspace(0, 1, self.seq_len)[None, :, None]

        if emb_dim > 1:
            bands = (emb_dim - 1) // 2
        # Angular positions used for the complex exponential bands, shape [1, L, 1].
        t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
        w = 2 * math.pi * t_rescaled / seq_len

        f = torch.linspace(1e-4, bands - 1, bands)[None, None]
        z = torch.exp(-1j * f * w)
        # Concatenate time with the real and imaginary parts of the exponentials: shape [1, L, emb_dim].
        z = torch.cat([t, z.real, z.imag], dim=-1)
        self.z = nn.Parameter(z, requires_grad=False)

    def forward(self, L):
        return self.z[:, :L]
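# Small illustrative sketch (not part of the original module): the embedding returns a
# `[1, L, emb_dim]` tensor whose channels are the normalized time plus the complex
# exponential bands; sizes are illustrative assumptions.
def _positional_embedding_shape_sketch():
    pos_emb = PositionalEmbedding(emb_dim=3, seq_len=128)
    z = pos_emb(64)
    assert z.shape == (1, 64, 3)
    return z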
class ImplicitLongConvolution(nn.Module):
    """
    Long convolution with an implicit filter parameterized by an MLP.

    Args:
        hidden_size (int):
            The number of expected features in the input and output.
        max_len (int):
            The maximum sequence length.
        d_emb (Optional[int]):
            The dimension of the positional embeddings. Must be odd and greater than or equal to 3
            (time, sine and cosine). Defaults to 3.
        d_hidden (Optional[int]):
            The number of features in the hidden layer of the MLP. Defaults to 16.

    Attributes:
        pos_emb (`PositionalEmbedding`): The positional embedding layer.
        mlp (`nn.Sequential`): The MLP that parameterizes the implicit filter.
    """

    def __init__(
        self,
        hidden_size: int,
        max_len: int,
        d_emb: int = 3,
        d_hidden: int = 16,
        **kwargs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.d_emb = d_emb

        assert (
            d_emb % 2 != 0 and d_emb >= 3
        ), "d_emb must be odd and greater or equal to 3 (time, sine and cosine)"
        self.pos_emb = PositionalEmbedding(d_emb, max_len)

        self.mlp = nn.Sequential(
            nn.Linear(d_emb, d_hidden),
            torch.nn.ReLU(),
            nn.Linear(d_hidden, hidden_size),
        )

    def filter(self, seq_len: int, *args, **kwargs):
        k = self.mlp(self.pos_emb(seq_len))
        return k.transpose(1, 2)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """
        Args:
            x: [batch_size, seq_len, hidden_size] tensor
        Returns:
            y: [batch_size, seq_len, hidden_size] tensor
        """
        x = x.transpose(1, 2)
        k = self.filter(x.shape[-1])
        y = fft_conv(x, k, dropout_mask=None, gelu=False)

        y = y.transpose(1, 2)
        return y.to(dtype=x.dtype)
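# Hedged usage sketch (not part of the original module): the long-convolution filter is
# generated by an MLP over positional embeddings instead of being stored explicitly;
# sizes are illustrative assumptions.
def _implicit_long_convolution_usage_sketch():
    conv = ImplicitLongConvolution(hidden_size=64, max_len=128, d_emb=3, d_hidden=16)
    x = torch.randn(2, 128, 64)
    return conv(x)  # [2, 128, 64]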