from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F

from lingua.transformer import FeedForward, InitStdFactor, RMSNorm
from lingua.probe import log_stats

from apps.fastRNN.component.rnn_common import conv1d, scan


@dataclass
class BaseHawkArgs:
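    """Configuration for the BaseHawk model and its blocks."""
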
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 1

    multiple_of: int = 256
    ffn_dim_multiplier: Optional[float] = None
    lru_dim_multiplier: Optional[float] = None

    conv_size: Optional[int] = None

    norm_eps: float = 1e-5

    init_base_std: Optional[float] = None
    init_std_factor: str = "disabled"


_MAX_SQRT_GRADIENT: float = 1000.0


class SqrtBoundDerivative(torch.autograd.Function):
    """Computes a square root with a gradient clipped at `_MAX_SQRT_GRADIENT`."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        """The forward pass, which is a normal `sqrt`."""
        ctx.save_for_backward(x)
        return torch.sqrt(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """The backward pass, which clips the `sqrt` gradient."""
        (x,) = ctx.saved_tensors
        clipped_x_times_4 = torch.clip(4.0 * x, min=1 / (_MAX_SQRT_GRADIENT**2))
        return grad_output / torch.sqrt(clipped_x_times_4)


def sqrt_bounded_derivative(x: torch.Tensor) -> torch.Tensor:
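    """Applies `sqrt` while clipping the gradient via `SqrtBoundDerivative`."""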
    return SqrtBoundDerivative.apply(x)


class RGLRU(nn.Module):
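    """Real-Gated Linear Recurrent Unit (RG-LRU): a diagonal linear recurrence
    whose per-channel decay is controlled by input-dependent gates, with an
    optional short causal conv1d on the input."""
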
    def __init__(
        self,
        dim: int,
        n_heads: int,
        head_dim: int,
        conv_size: Optional[int] = None,
    ):
        super().__init__()

        assert dim % n_heads == 0, f"dim {dim} must be divisible by n_heads {n_heads}"

        self.dim = dim
        self.head_dim = head_dim
        self.n_heads = n_heads
        assert (
            head_dim * n_heads == dim
        ), f"dim {dim} must be equal to n_heads {n_heads} * head_dim {head_dim}"

        self.c = 8.0

        self.conv_size = conv_size
        if conv_size is not None:
            assert (dim % 8 == 0) and (
                conv_size in [2, 3, 4]
            ), f"Causal conv1d only supports conv_size in [2, 3, 4] and dim % 8 == 0, got {dim} and {conv_size}"
            self.conv_dim = self.dim
            self.conv_weight = nn.Parameter(torch.empty((self.conv_dim, conv_size)))

        self.register_parameter("a", nn.Parameter(torch.empty(head_dim)))

        self.input_gate = nn.Linear(n_heads * head_dim, dim, bias=False)

        self.a_gate = nn.Linear(n_heads * head_dim, dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        tok_idx: torch.Tensor,
        cu_seqlens: torch.Tensor,
        impl: str = "parallel",
    ) -> torch.Tensor:
        bsz, seqlen, _ = x.shape

        if self.conv_size is not None:
            conv1d_w = log_stats(self.conv_weight, "conv1d.w")
            x = conv1d(
                x=x.transpose(1, 2),
                conv_weight=conv1d_w,
                tok_idx=tok_idx,
                cu_seqlens=cu_seqlens,
                impl=impl,
                cache=self.cache.conv_cache if hasattr(self, "cache") else None,
            ).transpose(1, 2)

        gate_x = torch.sigmoid(self.input_gate(x))
        gate_a = torch.sigmoid(self.a_gate(x))

        gate_x = gate_x.transpose(1, 2).reshape(
            bsz * self.n_heads, self.head_dim, seqlen
        )
        gate_a = gate_a.transpose(1, 2).reshape(
            bsz * self.n_heads, self.head_dim, seqlen
        )
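
        # RG-LRU recurrence: the decay is a_t = exp(-c * softplus(a) * gate_a)
        # and the input is scaled by sqrt(1 - a_t**2), keeping the hidden state
        # at roughly constant scale regardless of the decay.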
        a = (
            F.softplus(self.a)
            .unsqueeze(0)
            .unsqueeze(-1)
            .expand(bsz * self.n_heads, self.head_dim, seqlen)
        )

        log_a = -self.c * gate_a * a
        a = log_a.exp()
        multiplier = sqrt_bounded_derivative(1.0 - (2.0 * log_a).exp())

        x = x.transpose(1, 2).reshape(bsz * self.n_heads, self.head_dim, seqlen)

        h = scan(
            a=a.contiguous(),
            b=(multiplier * gate_x * x).contiguous(),
            cu_seqlens=cu_seqlens,
            impl=impl,
            cache=self.cache.state_cache if hasattr(self, "cache") else None,
        )

        h = h.view(bsz, self.dim, seqlen).transpose(1, 2)
        h = log_stats(h, "hidden_state")

        return h

    def reset_parameters(self, init_std, factor):
        in_init_std = init_std or (self.dim ** (-0.5))
        in_init_std = in_init_std / factor

        for w in [self.input_gate, self.a_gate]:
            nn.init.trunc_normal_(
                w.weight, std=in_init_std, a=-3 * in_init_std, b=3 * in_init_std
            )
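
        # `a` is set to 0.5 * log(u) with u ~ U[min_rad**2, max_rad**2], i.e.
        # the log of a recurrence radius in [0.9, 0.999].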
        min_rad, max_rad = 0.9, 0.999
        self.a.data.uniform_(min_rad**2 + 1e-8, max_rad**2 + 1e-8)
        self.a.data.log_().mul_(0.5)

        if self.conv_size is not None:
            conv_std = init_std or (self.conv_size ** (-0.5))
            nn.init.trunc_normal_(
                self.conv_weight, std=conv_std, a=-3 * conv_std, b=3 * conv_std
            )


class RGLRUBlock(nn.Module):
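    """Gated recurrent unit around the RG-LRU: `wx` projects into the
    recurrence, `wy` provides a SiLU gate on its output, and `wo` projects
    back to `dim`."""
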
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        n_heads: int,
        multiple_of: int,
        lru_dim_multiplier: Optional[float],
        conv_size: Optional[int] = None,
    ):
        super().__init__()

        if lru_dim_multiplier is not None:
            hidden_dim = int(lru_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        assert (
            hidden_dim % n_heads == 0
        ), f"Hidden dim must be divisible by n_heads: {hidden_dim} % {n_heads} != 0"

        self.dim = dim
        self.hidden_dim = hidden_dim

        self.wy = nn.Linear(dim, hidden_dim, bias=False)
        self.wx = nn.Linear(dim, hidden_dim, bias=False)

        self.rglru = RGLRU(
            dim=hidden_dim,
            n_heads=n_heads,
            head_dim=hidden_dim // n_heads,
            conv_size=conv_size,
        )

        self.wo = nn.Linear(hidden_dim, dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        tok_idx: torch.Tensor,
        cu_seqlens: torch.Tensor,
        impl: str = "parallel",
    ) -> torch.Tensor:
        h = self.rglru(self.wx(x), tok_idx=tok_idx, cu_seqlens=cu_seqlens, impl=impl)
        h = h * F.silu(self.wy(x))
        y = x + self.wo(h)

        return y

    def init_weights(self, init_std: Optional[float], factor: float):
        self.rglru.reset_parameters(init_std, factor)

        in_init_std = init_std or (self.dim ** (-0.5))
        out_init_std = init_std or (self.hidden_dim ** (-0.5))
        in_init_std = in_init_std / factor
        out_init_std = out_init_std / factor

        for w in [self.wy, self.wx]:
            nn.init.trunc_normal_(
                w.weight, std=in_init_std, a=-3 * in_init_std, b=3 * in_init_std
            )

        nn.init.trunc_normal_(
            self.wo.weight, std=out_init_std, a=-3 * out_init_std, b=3 * out_init_std
        )


class HawkBlock(nn.Module):
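    """Pre-norm residual block: RMSNorm -> RGLRUBlock, followed by
    RMSNorm -> FeedForward."""
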
    def __init__(self, args: BaseHawkArgs):
        super().__init__()

        self.rglru_block = RGLRUBlock(
            dim=args.dim,
            hidden_dim=int(4 / 3 * args.dim),
            n_heads=args.n_heads,
            conv_size=args.conv_size,
            multiple_of=args.multiple_of,
            lru_dim_multiplier=args.lru_dim_multiplier,
        )

        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )

        self.rglru_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        tok_idx: torch.Tensor,
        cu_seqlens: torch.Tensor,
        impl: str = "parallel",
    ) -> torch.Tensor:
        x = x + self.rglru_block(
            self.rglru_norm(x), tok_idx=tok_idx, cu_seqlens=cu_seqlens, impl=impl
        )
        x = x + self.feed_forward(self.ffn_norm(x))
        return x

    def init_weights(self, init_std: Optional[float], factor: float):
        self.rglru_block.init_weights(init_std, factor)
        self.rglru_norm.reset_parameters()
        self.feed_forward.reset_parameters()
        self.ffn_norm.reset_parameters()


class BaseHawk(nn.Module):
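    """A stack of `n_layers` HawkBlocks operating on already-embedded hidden
    states; token embedding and the output head live outside this module."""
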
    def __init__(self, args: BaseHawkArgs):
        super().__init__()

        self.dim = args.dim
        self.init_base_std = args.init_base_std
        self.init_std_factor = InitStdFactor(args.init_std_factor)

        self.layers = nn.ModuleList()
        for _ in range(args.n_layers):
            self.layers.append(HawkBlock(args))

    def forward(
        self,
        h: torch.Tensor,
        tok_idx: torch.Tensor,
        cu_seqlens: torch.Tensor,
        impl: str = "parallel",
    ) -> torch.Tensor:
        for layer in self.layers:
            h = layer(h, tok_idx=tok_idx, cu_seqlens=cu_seqlens, impl=impl)
        return h

    def reset_parameters(self):
        pass

    def init_weights(self):
        self.reset_parameters()
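        # Per-layer init-std scaling selected by `init_std_factor`.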
        for depth, layer in enumerate(self.layers):
            factor = {
                InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
                InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
                InitStdFactor.DIM_RATIO: self.dim / 4096,
                InitStdFactor.DISABLED: 1.0,
            }[self.init_std_factor]

            layer.init_weights(self.init_base_std, factor)
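

# Minimal smoke-test sketch. The (bsz, seqlen) `tok_idx` layout and the
# flattened-offset `cu_seqlens` below follow common varlen conventions and are
# assumptions about what the `conv1d`/`scan` kernels in
# apps.fastRNN.component.rnn_common expect, not a documented API.
if __name__ == "__main__":
    args = BaseHawkArgs(dim=256, n_layers=2, n_heads=2, conv_size=4)
    model = BaseHawk(args)
    model.init_weights()

    bsz, seqlen = 2, 16
    h = torch.randn(bsz, seqlen, args.dim)
    tok_idx = torch.arange(seqlen).repeat(bsz, 1)
    cu_seqlens = torch.tensor([0, seqlen, 2 * seqlen], dtype=torch.int32)

    out = model(h, tok_idx=tok_idx, cu_seqlens=cu_seqlens, impl="parallel")
    print(out.shape)  # expected: torch.Size([2, 16, 256])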