| | |
| | |
| |
|
| | import math |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from einops import rearrange |
| | import opt_einsum as oe |
| | contract = oe.contract |
| |
|
| | """ Utils for the training loop. Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py """ |
| |
|
class OptimModule(nn.Module):
    """Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters"""

    def register(self, name, tensor, lr=None, wd=0.0):
        """Register a tensor either as a frozen buffer or a trainable
        parameter, attaching per-tensor optimizer hyperparameters.

        Args:
            name: attribute name to register the tensor under.
            tensor: the tensor to register.
            lr: learning rate override; ``0.0`` freezes the tensor (buffer).
            wd: weight-decay override for this tensor.
        """
        # lr == 0.0 means "never update": store as a non-trainable buffer.
        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))

        # Stash optimizer overrides on the tensor itself; the training loop
        # is expected to read `_optim` when building parameter groups.
        hypers = {}
        if lr is not None:
            hypers["lr"] = lr
        if wd is not None:
            hypers["weight_decay"] = wd
        setattr(getattr(self, name), "_optim", hypers)
| |
|
| |
|
def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None):
    """FFT-based causal convolution of `u` with kernel `k`, plus a skip term.

    Args:
        u: input, presumably of shape (B, H, L) — an extra leading dim is
            also handled by the `len(u.shape) > 3` branch.
        k: convolution kernel whose last dim is the length L.
        D: per-channel skip ("residual") weight of shape (H,).
        dropout_mask: optional (B, H) mask applied to the output, or None.
        gelu: apply GELU to the output before the dropout mask.
        k_rev: optional reversed (anti-causal) kernel; its conjugate is added
            in the frequency domain, which corresponds to time reversal.

    Returns:
        Tensor with the same shape and dtype as `u`.
    """
    seqlen = u.shape[-1]
    # Zero-pad to 2*L so the circular FFT convolution equals a linear one.
    fft_size = 2 * seqlen
    k_f = torch.fft.rfft(k, n=fft_size) / fft_size
    if k_rev is not None:
        k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
        k_f = k_f + k_rev_f.conj()
    u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)

    if len(u.shape) > 3:
        k_f = k_f.unsqueeze(1)

    # norm="forward" makes the inverse unscaled, cancelling the 1/fft_size
    # applied to k_f above; keep only the first L (causal) outputs.
    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]

    # Fix: D is per-channel (H,), so it must broadcast over the length
    # dimension; the previous `u * D` aligned D with the last (length) dim.
    out = y + u * D.unsqueeze(-1)

    if gelu:
        out = F.gelu(out)
    if dropout_mask is not None:
        return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype)
    else:
        return out.to(dtype=u.dtype)
| |
|
| |
|
@torch.jit.script
def mul_sum(q, y):
    """Elementwise product of `q` and `y`, reduced over dim 1."""
    prod = q * y
    return torch.sum(prod, dim=1)
| |
|
| |
|
class Sin(nn.Module):
    """Sine activation with (optionally trainable) per-channel frequencies."""

    def __init__(self, dim, w=10, w_mod=1, train_freq=True):
        super().__init__()
        base = w * torch.ones(1, dim)
        if train_freq:
            self.freq = nn.Parameter(base)
        else:
            # NOTE(review): a plain tensor attribute is not moved by
            # `.to(device)` and is absent from `state_dict()` — presumably
            # intentional for a fixed frequency, but verify.
            self.freq = base
        self.w_mod = w_mod  # global, non-trainable frequency scale

    def forward(self, x):
        return torch.sin(self.w_mod * self.freq * x)
| |
|
| |
|
class PositionalEmbedding(OptimModule):
    def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float = 1e-5, **kwargs):
        """Complex exponential positional embeddings for Hyena filters.

        Precomputes `z = [t, Re(e^{-i f w}), Im(e^{-i f w})]` over `seq_len`
        positions, with `(emb_dim - 1) // 2` frequency bands.

        Args:
            emb_dim: embedding width; must be > 1 for any bands to exist
                (`bands` is undefined otherwise — callers assert emb_dim >= 3).
            seq_len: maximum sequence length to precompute.
            lr_pos_emb: learning-rate override for the trainable `z` table.
        """
        super().__init__()
        self.seq_len = seq_len

        # Normalized time in [0, 1], shaped (1, seq_len, 1).
        t = torch.linspace(0, 1, self.seq_len)[None, :, None]

        if emb_dim > 1:
            bands = (emb_dim - 1) // 2
        # Integer positions mapped to angles 2*pi*n / seq_len.
        positions = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
        angles = 2 * math.pi * positions / seq_len

        freqs = torch.linspace(1e-4, bands - 1, bands)[None, None]
        phase = torch.exp(-1j * freqs * angles)
        z = torch.cat([t, phase.real, phase.imag], dim=-1)
        # `z` is trainable with its own lr; `t` is frozen (lr=0.0 -> buffer).
        self.register("z", z, lr=lr_pos_emb)
        self.register("t", t, lr=0.0)

    def forward(self, L):
        # Truncate the precomputed tables to the requested length.
        return self.z[:, :L], self.t[:, :L]
| |
|
| |
|
class ExponentialModulation(OptimModule):
    """Multiplies a filter by per-channel exponential decay envelopes.

    Decay rates are linearly spaced so the fastest channel reaches `target`
    after `fast_decay_pct` of the (normalized) sequence and the slowest
    after `slow_decay_pct`.
    """

    def __init__(
        self,
        d_model,
        fast_decay_pct=0.3,
        slow_decay_pct=1.5,
        target=1e-2,
        modulation_lr=0.0,
        shift: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        self.shift = shift
        # log(target) is negative, so both rates are negative; dividing by a
        # smaller pct yields a faster (more negative) decay.
        max_decay = math.log(target) / fast_decay_pct
        min_decay = math.log(target) / slow_decay_pct
        deltas = torch.linspace(min_decay, max_decay, d_model)[None, None]
        # Frozen by default (modulation_lr=0.0 registers a buffer).
        self.register("deltas", deltas, lr=modulation_lr)

    def forward(self, t, x):
        envelope = torch.exp(-t * self.deltas.abs())
        return x * (envelope + self.shift)
| |
|
| |
|
class HyenaFilter(OptimModule):
    def __init__(
        self,
        d_model,
        emb_dim=3,
        order=16,
        seq_len=1024,
        lr=1e-3,
        lr_pos_emb=1e-5,
        dropout=0.0,
        w=1,
        w_mod=1,
        wd=0,
        bias=True,
        num_inner_mlps=2,
        linear_mixer=False,
        modulate: bool = True,
        normalized=False,
        bidirectional=False,
        **kwargs,
    ):
        """
        Implicit long filter with modulation.

        Args:
            d_model: number of channels in the input
            emb_dim: dimension of the positional encoding (`emb_dim` - 1) // 2 is the number of bands
            order: width of the FFN
            num_inner_mlps: number of inner linear layers inside filter MLP

        Note:
            filter_dropout is not implemented
        """
        super().__init__()

        self.d_model = d_model
        self.emb_dim = emb_dim
        self.seq_len = seq_len
        self.modulate = modulate
        self.use_bias = bias
        self.bidirectional = bidirectional

        # Per-channel skip weight, passed to fftconv_ref as its `D` argument.
        self.bias = nn.Parameter(torch.randn(self.d_model))
        # NOTE(review): `self.dropout` is created but never applied in this
        # class (see docstring: filter_dropout is not implemented).
        self.dropout = nn.Dropout(dropout)

        # Sine activation. NOTE(review): this single instance — and hence its
        # trainable `freq` — is shared by every activation layer of both the
        # forward and reverse filter MLPs below.
        act = Sin(dim=order, w=w, w_mod=w_mod)
        assert (
            emb_dim % 2 != 0 and emb_dim >= 3
        ), "emb_dim must be odd and greater or equal to 3 (time, sine and cosine)"
        self.pos_emb = PositionalEmbedding(emb_dim, seq_len, lr_pos_emb)

        # MLP mapping positional features (emb_dim) -> one filter value per
        # channel (d_model); `linear_mixer=True` collapses it to one Linear.
        if linear_mixer is False:
            self.implicit_filter = nn.Sequential(
                nn.Linear(emb_dim, order),
                act,
            )
            for i in range(num_inner_mlps):
                self.implicit_filter.append(nn.Linear(order, order))
                self.implicit_filter.append(act)
            self.implicit_filter.append(nn.Linear(order, d_model, bias=False))
        else:
            self.implicit_filter = nn.Sequential(
                nn.Linear(emb_dim, d_model, bias=False),
            )

        if self.bidirectional:
            # Separate MLP generating the anti-causal (time-reversed) kernel.
            self.implicit_filter_rev = nn.Sequential(
                nn.Linear(emb_dim, order),
                act,
            )
            for i in range(num_inner_mlps):
                self.implicit_filter_rev.append(nn.Linear(order, order))
                self.implicit_filter_rev.append(act)
            self.implicit_filter_rev.append(nn.Linear(order, d_model, bias=False))

        self.modulation = ExponentialModulation(d_model, **kwargs)

        self.normalized = normalized
        # Attach per-tensor optimizer overrides (lr / weight decay) to every
        # weight of the forward filter MLP, mirroring OptimModule.register.
        # NOTE(review): `implicit_filter_rev` does not receive these
        # overrides — confirm whether that asymmetry is intentional.
        for c in self.implicit_filter.children():
            for name, v in c.state_dict().items():
                optim = {"weight_decay": wd, "lr": lr}
                setattr(getattr(c, name), "_optim", optim)

    def filter(self, L, *args, **kwargs):
        # Evaluate the implicit (forward) filter at the first L positions.
        z, t = self.pos_emb(L)
        h = self.implicit_filter(z)
        if self.modulate:
            # Apply exponential decay envelope over time.
            h = self.modulation(t, h)
        if self.normalized:
            # L1-normalize over the channel dimension.
            h = h / torch.norm(h, dim=-1, p=1, keepdim=True)
        return h

    def filter_rev(self, L, *args, **kwargs):
        # Same as `filter`, but through the reverse-kernel MLP
        # (only defined when bidirectional=True).
        z, t = self.pos_emb(L)
        h = self.implicit_filter_rev(z)
        if self.modulate:
            h = self.modulation(t, h)
        if self.normalized:
            h = h / torch.norm(h, dim=-1, p=1, keepdim=True)
        return h

    def forward(self, x, L, k_fwd=None, k_rev=None, bias=None, *args, **kwargs):
        # Convolve `x` with the learned filter(s) via fftconv_ref.
        # NOTE(review): fftconv_ref FFTs `k` over its last dim, so explicit
        # `k_fwd`/`k_rev` are presumably pre-rearranged to (d_model, L) by
        # the caller — confirm against call sites.
        if k_fwd is None:
            k_fwd = self.filter(L)
        if self.bidirectional and k_rev is None:
            k_rev = self.filter_rev(L)

        # Some filter implementations return (kernel, state) tuples.
        k_fwd = k_fwd[0] if type(k_fwd) is tuple else k_fwd
        if bias is None:
            bias = self.bias
        # use_bias=False zeroes the skip term while keeping gradient flow shape.
        bias = bias if self.use_bias else 0 * bias

        if self.bidirectional:
            k_rev = k_rev[0] if type(k_rev) is tuple else k_rev
            # Place the causal kernel in [0, L) and the flipped anti-causal
            # kernel in [L, 2L) of a length-2L kernel.
            k = F.pad(k_fwd, (0, L)) \
                + F.pad(k_rev.flip(-1), (L, 0))
        else:
            k = k_fwd

        y = fftconv_ref(
            x,
            k,
            bias,
            dropout_mask=None,
            gelu=False,
        )

        return y.to(dtype=x.dtype)