import math
from enum import Enum

import einops
import torch
import torch.nn.functional as F
from torch import nn

from rscd.models.decoderheads.vision_lstm_util import interpolate_sincos, to_ntuple, VitPatchEmbed, VitPosEmbed2d, DropPath


class SequenceTraversal(Enum):
    ROWWISE_FROM_TOP_LEFT = "rowwise_from_top_left"
    ROWWISE_FROM_BOT_RIGHT = "rowwise_from_bot_right"


def bias_linspace_init_(param: torch.Tensor, start: float = 3.4, end: float = 6.0) -> torch.Tensor:
    """Linearly spaced bias init across dimensions."""
    assert param.dim() == 1, f"param must be 1-dimensional (typically a bias), got {param.dim()}"
    n_dims = param.shape[0]
    init_vals = torch.linspace(start, end, n_dims)
    with torch.no_grad():
        param.copy_(init_vals)
    return param
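
# Example for bias_linspace_init_ (illustrative, not part of the original module):
# with the default range, a 4-element bias is filled with evenly spaced values.
#   >>> bias_linspace_init_(torch.empty(4))
#   tensor([3.4000, 4.2667, 5.1333, 6.0000])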


def small_init_(param: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
    the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution.
    Adopted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py.
    """
    std = math.sqrt(2 / (5 * dim))
    torch.nn.init.normal_(param, mean=0.0, std=std)
    return param
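
# Worked example for small_init_: for dim=192 the weights are drawn from N(0, std^2)
# with std = sqrt(2 / (5 * 192)) ≈ 0.0456.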


def wang_init_(param: torch.Tensor, dim: int, num_blocks: int):
    """ Adopted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py. """
    std = 2 / num_blocks / math.sqrt(dim)
    torch.nn.init.normal_(param, mean=0.0, std=std)
    return param
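
# Worked example for wang_init_: for dim=192 and num_blocks=24 the standard deviation is
# 2 / 24 / sqrt(192) ≈ 0.006, i.e. deeper stacks get smaller output-projection inits.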


def parallel_stabilized_simple(
    queries: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    igate_preact: torch.Tensor,
    fgate_preact: torch.Tensor,
    lower_triangular_matrix: torch.Tensor = None,
    stabilize_rowwise: bool = True,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    This is the mLSTM cell in parallel form.
    This version is stabilized. We control the range of exp() arguments by
    ensuring that they are always smaller than 0.0 by subtracting the maximum.

    Args:
        queries: (torch.Tensor) (B, NH, S, DH)
        keys: (torch.Tensor) (B, NH, S, DH)
        values: (torch.Tensor) (B, NH, S, DH)
        igate_preact: (torch.Tensor) (B, NH, S, 1)
        fgate_preact: (torch.Tensor) (B, NH, S, 1)
        lower_triangular_matrix: (torch.Tensor) (S, S). Defaults to None.
        stabilize_rowwise: (bool) Whether to stabilize the combination matrix C rowwise (take the maximum per row).
            Alternative: subtract the maximum over all rows. Defaults to True.
        eps: (float) small constant to avoid division by 0. Defaults to 1e-6.

    Returns:
        torch.Tensor: (B, NH, S, DH), h_tilde_state
    """
    B, NH, S, DH = queries.shape
    _dtype, _device = queries.dtype, queries.device

    # log of the forget gates (sigmoid in log space), shape (B, NH, S, 1)
    log_fgates = torch.nn.functional.logsigmoid(fgate_preact)
    if lower_triangular_matrix is None or S < lower_triangular_matrix.size(-1):
        ltr = torch.tril(torch.ones((S, S), dtype=torch.bool, device=_device))
    else:
        ltr = lower_triangular_matrix
    assert ltr.dtype == torch.bool, f"lower_triangular_matrix must be of dtype bool, got {ltr.dtype}"

    # cumulative sum of the log forget gates, prepended with a zero: (B, NH, S+1, 1)
    log_fgates_cumsum = torch.cat(
        [
            torch.zeros((B, NH, 1, 1), dtype=_dtype, device=_device),
            torch.cumsum(log_fgates, dim=-2),
        ],
        dim=-2,
    )
    # repeat the cumsum along the last dimension: (B, NH, S+1, S+1)
    rep_log_fgates_cumsum = log_fgates_cumsum.repeat(1, 1, 1, S + 1)
    # entry (i, j) holds the accumulated log forget gates between step j and step i
    _log_fg_matrix = rep_log_fgates_cumsum - rep_log_fgates_cumsum.transpose(-2, -1)
    # causal masking: future positions receive -inf, shape (B, NH, S, S)
    log_fg_matrix = torch.where(ltr, _log_fg_matrix[:, :, 1:, 1:], -float("inf"))

    # gate decay matrix D: combination of forget and input gates, (B, NH, S, S)
    log_D_matrix = log_fg_matrix + igate_preact.transpose(-2, -1)
    # stabilize D by subtracting the maximum before exponentiation
    if stabilize_rowwise:
        max_log_D, _ = torch.max(log_D_matrix, dim=-1, keepdim=True)
    else:
        max_log_D = torch.max(log_D_matrix.view(B, NH, -1), dim=-1, keepdim=True)[0].unsqueeze(-1)

    log_D_matrix_stabilized = log_D_matrix - max_log_D
    D_matrix = torch.exp(log_D_matrix_stabilized)

    keys_scaled = keys / math.sqrt(DH)

    # combination matrix C = (Q K^T / sqrt(DH)) * D, normalized rowwise
    qk_matrix = queries @ keys_scaled.transpose(-2, -1)
    C_matrix = qk_matrix * D_matrix
    normalizer = torch.maximum(C_matrix.sum(dim=-1, keepdim=True).abs(), torch.exp(-max_log_D))
    C_matrix_normalized = C_matrix / (normalizer + eps)

    # retrieved values: (B, NH, S, DH)
    h_tilde_state = C_matrix_normalized @ values

    return h_tilde_state
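
# Shape-check sketch for parallel_stabilized_simple (illustrative values):
#   q = k = v = torch.randn(2, 4, 16, 8)                   # (B, NH, S, DH)
#   i_pre = torch.randn(2, 4, 16, 1)                        # input gate pre-activations
#   f_pre = torch.randn(2, 4, 16, 1)                        # forget gate pre-activations
#   h = parallel_stabilized_simple(q, k, v, i_pre, f_pre)   # -> (2, 4, 16, 8)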


class LinearHeadwiseExpand(nn.Module):
    """
    A structured projection layer that splits the input features into `num_heads` groups
    and projects each group with its own (dim_per_head x dim_per_head) weight matrix.
    In this variant the per-head output dimension equals the per-head input dimension,
    so the overall feature dimension is preserved.
    """

    def __init__(self, dim, num_heads, bias=False):
        super().__init__()
        assert dim % num_heads == 0
        self.dim = dim
        self.num_heads = num_heads

        dim_per_head = dim // num_heads
        self.weight = nn.Parameter(torch.empty(num_heads, dim_per_head, dim_per_head))
        if bias:
            self.bias = nn.Parameter(torch.empty(dim))
        else:
            self.bias = None
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight.data, mean=0.0, std=math.sqrt(2 / 5 / self.weight.shape[-1]))
        if self.bias is not None:
            nn.init.zeros_(self.bias.data)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # split into heads, apply the per-head projection, then merge the heads again
        x = einops.rearrange(x, "... (nh d) -> ... nh d", nh=self.num_heads)
        x = einops.einsum(
            x,
            self.weight,
            "... nh d, nh out_d d -> ... nh out_d",
        )
        x = einops.rearrange(x, "... nh out_d -> ... (nh out_d)")
        if self.bias is not None:
            x = x + self.bias
        return x

    def extra_repr(self):
        return (
            f"dim={self.dim}, "
            f"num_heads={self.num_heads}, "
            f"bias={self.bias is not None}, "
        )
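
# Example (illustrative): LinearHeadwiseExpand(dim=8, num_heads=2) holds a weight of shape
# (2, 4, 4); an input of shape (B, S, 8) is split into two 4-dim heads, each projected by
# its own 4x4 matrix, giving an output of shape (B, S, 8).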


class CausalConv1d(nn.Module):
    """
    Implements causal depthwise convolution of a time series tensor.
    Input: Tensor of shape (B, T, F), i.e. (batch, time, feature)
    Output: Tensor of shape (B, T, F)

    Args:
        dim: number of features in the input tensor; each feature is convolved
            independently (depthwise, groups=dim)
        kernel_size: size of the kernel for the depthwise convolution
        bias: whether to use bias in the depthwise convolution
    """

    def __init__(self, dim, kernel_size=4, bias=True):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.bias = bias
        # pad on both sides, then trim the right side to make the convolution causal
        self.pad = kernel_size - 1
        self.conv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=kernel_size,
            padding=self.pad,
            groups=dim,
            bias=bias,
        )
        self.reset_parameters()

    def reset_parameters(self):
        self.conv.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, T, F) -> (B, F, T) for nn.Conv1d
        x = einops.rearrange(x, "b l d -> b d l")
        x = self.conv(x)
        # drop the trailing positions introduced by the symmetric padding
        x = x[:, :, :-self.pad]
        # (B, F, T) -> (B, T, F)
        x = einops.rearrange(x, "b d l -> b l d")
        return x
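
# Example (illustrative): CausalConv1d(dim=8, kernel_size=4) maps x of shape (2, 32, 8) to the
# same shape; with the effective left-only padding of kernel_size - 1 = 3 steps, the output at
# time t depends only on inputs at times <= t.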


class LayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False. """

    def __init__(
        self,
        ndim: int = -1,
        weight: bool = True,
        bias: bool = False,
        eps: float = 1e-5,
        residual_weight: bool = True,
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(ndim)) if weight else None
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
        self.eps = eps
        self.residual_weight = residual_weight
        self.ndim = ndim
        self.reset_parameters()

    @property
    def weight_proxy(self) -> torch.Tensor:
        if self.weight is None:
            return None
        if self.residual_weight:
            return 1.0 + self.weight
        else:
            return self.weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(
            x,
            normalized_shape=(self.ndim,),
            weight=self.weight_proxy,
            bias=self.bias,
            eps=self.eps,
        )

    def reset_parameters(self):
        if self.weight_proxy is not None:
            if self.residual_weight:
                nn.init.zeros_(self.weight)
            else:
                nn.init.ones_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)
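
# Note on residual_weight (illustrative): with residual_weight=True the learnable scale is
# stored as an offset from 1, so a freshly constructed LayerNorm(ndim=4) applies an effective
# per-channel weight of 1.0 + 0.0 = 1.0 while the underlying parameter starts at zero.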


class MultiHeadLayerNorm(LayerNorm):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.ndim == 4, "Input must be 4D tensor (B, NH, S, DH)"
        B, NH, S, DH = x.shape

        # normalize each head separately by casting the problem to a group norm with
        # NH groups over the flattened (B * S, NH * DH) tensor
        gn_in_1 = x.transpose(1, 2)  # (B, S, NH, DH)
        gn_in_2 = gn_in_1.reshape(B * S, NH * DH)
        out = F.group_norm(
            gn_in_2,
            num_groups=NH,
            weight=self.weight_proxy,
            bias=self.bias,
            eps=self.eps,
        )
        # (B * S, NH * DH) -> (B, NH, S, DH)
        out = out.view(B, S, NH, DH).transpose(1, 2)
        return out
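
# Example (illustrative): MultiHeadLayerNorm(ndim=32) applied to a tensor of shape
# (B, NH=4, S, DH=8) normalizes each head's 8 channels independently (4 groups) and returns
# the same (B, 4, S, 8) shape; ndim must equal NH * DH = 32 so the affine parameters cover
# all channels.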


class MatrixLSTMCell(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads

        self.igate = nn.Linear(3 * dim, num_heads)
        self.fgate = nn.Linear(3 * dim, num_heads)
        self.outnorm = MultiHeadLayerNorm(ndim=dim, weight=True, bias=False)
        self.causal_mask_cache = {}
        self.reset_parameters()

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        B, S, _ = q.shape

        # gate pre-activations are computed from the concatenated (q, k, v) before the head split
        if_gate_input = torch.cat([q, k, v], dim=-1)
        q = q.view(B, S, self.num_heads, -1)
        k = k.view(B, S, self.num_heads, -1)
        v = v.view(B, S, self.num_heads, -1)

        q = q.transpose(1, 2)  # (B, NH, S, DH)
        k = k.transpose(1, 2)  # (B, NH, S, DH)
        v = v.transpose(1, 2)  # (B, NH, S, DH)

        # input and forget gate pre-activations: (B, NH, S, 1)
        igate_preact = self.igate(if_gate_input)
        igate_preact = igate_preact.transpose(-1, -2).unsqueeze(-1)
        fgate_preact = self.fgate(if_gate_input)
        fgate_preact = fgate_preact.transpose(-1, -2).unsqueeze(-1)

        # cache the causal mask per (sequence length, device) to avoid reallocating it every call
        cache_key = (S, str(q.device))
        if cache_key in self.causal_mask_cache:
            causal_mask = self.causal_mask_cache[cache_key]
        else:
            causal_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device=q.device))
            self.causal_mask_cache[cache_key] = causal_mask

        h_state = parallel_stabilized_simple(
            queries=q,
            keys=k,
            values=v,
            igate_preact=igate_preact,
            fgate_preact=fgate_preact,
            lower_triangular_matrix=causal_mask,
        )

        h_state_norm = self.outnorm(h_state)
        h_state_norm = h_state_norm.transpose(1, 2).reshape(B, S, -1)

        return h_state_norm

    def reset_parameters(self):
        self.outnorm.reset_parameters()
        # forget gate: zero weight, linearly spaced positive bias (gates start mostly open)
        torch.nn.init.zeros_(self.fgate.weight)
        bias_linspace_init_(self.fgate.bias, start=3.0, end=6.0)
        # input gate: zero weight, small normal bias
        torch.nn.init.zeros_(self.igate.weight)
        torch.nn.init.normal_(self.igate.bias, mean=0.0, std=0.1)
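
# Usage sketch (illustrative dimensions): MatrixLSTMCell(dim=64, num_heads=4) consumes q, k, v
# of shape (B, S, 64) and returns a tensor of shape (B, S, 64); internally each of the 4 heads
# operates on 16-dimensional chunks.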


class ViLLayer(nn.Module):
    def __init__(
        self,
        dim,
        direction,
        expansion=2,
        qkv_block_size=4,
        proj_bias=False,
        conv_bias=True,
        kernel_size=4,
    ):
        super().__init__()
        # fall back to a smaller block size if dim is not divisible by the default
        if dim % qkv_block_size != 0:
            qkv_block_size = 2
        self.dim = dim
        self.direction = direction
        self.expansion = expansion
        self.qkv_block_size = qkv_block_size
        self.proj_bias = proj_bias
        self.conv_bias = conv_bias
        self.kernel_size = kernel_size

        inner_dim = expansion * dim
        num_heads = inner_dim // qkv_block_size
        self.proj_up = nn.Linear(
            in_features=dim,
            out_features=2 * inner_dim,
            bias=proj_bias,
        )
        self.q_proj = LinearHeadwiseExpand(
            dim=inner_dim,
            num_heads=num_heads,
            bias=proj_bias,
        )
        self.k_proj = LinearHeadwiseExpand(
            dim=inner_dim,
            num_heads=num_heads,
            bias=proj_bias,
        )
        self.v_proj = LinearHeadwiseExpand(
            dim=inner_dim,
            num_heads=num_heads,
            bias=proj_bias,
        )

        self.conv1d = CausalConv1d(
            dim=inner_dim,
            kernel_size=kernel_size,
            bias=conv_bias,
        )
        self.mlstm_cell = MatrixLSTMCell(
            dim=inner_dim,
            num_heads=qkv_block_size,
        )
        self.learnable_skip = nn.Parameter(torch.ones(inner_dim))

        self.proj_down = nn.Linear(
            in_features=inner_dim,
            out_features=dim,
            bias=proj_bias,
        )
        self.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, S, _ = x.shape

        # reverse the sequence for the backward traversal direction
        if self.direction == SequenceTraversal.ROWWISE_FROM_TOP_LEFT:
            pass
        elif self.direction == SequenceTraversal.ROWWISE_FROM_BOT_RIGHT:
            x = x.flip(dims=[1])
        else:
            raise NotImplementedError

        # up-projection, split into the mLSTM branch and the gating branch z
        x_inner = self.proj_up(x)
        x_mlstm, z = torch.chunk(x_inner, chunks=2, dim=-1)

        # mLSTM branch: causal conv + SiLU feed q and k, the raw features feed v
        x_mlstm_conv = self.conv1d(x_mlstm)
        x_mlstm_conv_act = F.silu(x_mlstm_conv)
        q = self.q_proj(x_mlstm_conv_act)
        k = self.k_proj(x_mlstm_conv_act)
        v = self.v_proj(x_mlstm)
        h_tilde_state = self.mlstm_cell(q=q, k=k, v=v)
        h_tilde_state_skip = h_tilde_state + (self.learnable_skip * x_mlstm_conv_act)

        # output gating with SiLU(z)
        h_state = h_tilde_state_skip * F.silu(z)

        # down-projection back to dim
        x = self.proj_down(h_state)

        # undo the sequence reversal
        if self.direction == SequenceTraversal.ROWWISE_FROM_TOP_LEFT:
            pass
        elif self.direction == SequenceTraversal.ROWWISE_FROM_BOT_RIGHT:
            x = x.flip(dims=[1])
        else:
            raise NotImplementedError

        return x

    def reset_parameters(self):
        # small init for the up-projection, wang init for the down-projection
        small_init_(self.proj_up.weight, dim=self.dim)
        if self.proj_up.bias is not None:
            nn.init.zeros_(self.proj_up.bias)
        wang_init_(self.proj_down.weight, dim=self.dim, num_blocks=1)
        if self.proj_down.bias is not None:
            nn.init.zeros_(self.proj_down.bias)

        nn.init.ones_(self.learnable_skip)

        def _init_qkv_proj(qkv_proj: LinearHeadwiseExpand):
            small_init_(qkv_proj.weight, dim=self.dim)
            if qkv_proj.bias is not None:
                nn.init.zeros_(qkv_proj.bias)

        _init_qkv_proj(self.q_proj)
        _init_qkv_proj(self.k_proj)
        _init_qkv_proj(self.v_proj)

        self.mlstm_cell.reset_parameters()
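
# Usage sketch (illustrative): ViLLayer(dim=192, direction=SequenceTraversal.ROWWISE_FROM_TOP_LEFT)
# maps a token sequence of shape (B, S, 192) to (B, S, 192). Internally it expands to
# inner_dim = 2 * 192 = 384; the headwise q/k/v projections use 384 // 4 = 96 blocks of size 4,
# while the mLSTM cell itself runs with qkv_block_size = 4 heads.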


class ViLBlock(nn.Module):
    def __init__(self, dim, direction, drop_path=0.0, norm_bias=False):
        super().__init__()
        self.dim = dim
        self.direction = direction
        self.drop_path = drop_path
        self.norm_bias = norm_bias

        self.drop_path = DropPath(drop_prob=drop_path)
        self.norm = LayerNorm(ndim=dim, weight=True, bias=norm_bias)
        self.layer = ViLLayer(dim=dim, direction=direction)

        self.reset_parameters()

    def _forward_path(self, x):
        x = self.norm(x)
        x = self.layer(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.drop_path(x, self._forward_path)
        return x

    def reset_parameters(self):
        self.layer.reset_parameters()
        self.norm.reset_parameters()
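
# Note (assumption about the DropPath imported from vision_lstm_util): its forward appears to take
# the residual stream x plus the residual branch _forward_path and to apply stochastic depth to the
# branch output before adding it back, i.e. roughly x + drop_path(ViLLayer(LayerNorm(x))), giving
# each ViLBlock a pre-norm residual structure.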


class VisionLSTM(nn.Module):
    def __init__(
        self,
        dim=192,
        input_shape=(3, 224, 224),
        patch_size=16,
        depth=24,
        output_shape=(1000,),
        mode="classifier",
        pooling="bilateral_avg",
        drop_path_rate=0.0,
        stride=None,
        alternation="bidirectional",
        drop_path_decay=False,
        legacy_norm=False,
    ):
        super().__init__()
        self.input_shape = input_shape
        self.output_shape = output_shape
        ndim = len(self.input_shape) - 1
        self.patch_size = to_ntuple(patch_size, n=ndim)
        self.dim = dim
        self.depth = depth
        self.stride = stride
        self.mode = mode
        self.pooling = pooling
        self.alternation = alternation
        self.drop_path_rate = drop_path_rate
        self.drop_path_decay = drop_path_decay

        # patch embedding
        self.patch_embed = VitPatchEmbed(
            dim=dim,
            stride=stride,
            num_channels=self.input_shape[0],
            resolution=self.input_shape[1:],
            patch_size=self.patch_size,
        )

        # learnable 2d positional embedding
        self.pos_embed = VitPosEmbed2d(seqlens=self.patch_embed.seqlens, dim=dim)

        # stochastic depth schedule (linear increase over depth if enabled)
        if drop_path_decay and drop_path_rate > 0.:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        else:
            dpr = [drop_path_rate] * depth

        # alternate the traversal direction between consecutive blocks
        directions = []
        if alternation == "bidirectional":
            for i in range(depth):
                if i % 2 == 0:
                    directions.append(SequenceTraversal.ROWWISE_FROM_TOP_LEFT)
                else:
                    directions.append(SequenceTraversal.ROWWISE_FROM_BOT_RIGHT)
        else:
            raise NotImplementedError(f"invalid alternation '{alternation}'")

        # blocks
        self.blocks = nn.ModuleList(
            [
                ViLBlock(
                    dim=dim,
                    drop_path=dpr[i],
                    direction=directions[i],
                )
                for i in range(depth)
            ]
        )
        # optional extra norm kept behind the legacy_norm flag
        if legacy_norm:
            self.legacy_norm = LayerNorm(dim, bias=False)
        else:
            self.legacy_norm = nn.Identity()
        self.norm = nn.LayerNorm(dim, eps=1e-6)

        # head
        if mode is None:
            # no head: return the patch tokens
            assert self.output_shape is None
            assert self.pooling is None
            self.head = None
            self.output_shape = (self.patch_embed.num_patches, dim)
        elif mode == "classifier":
            assert self.output_shape is not None and len(self.output_shape) == 1, \
                "define number of classes via output_shape=(num_classes,) (e.g. output_shape=(1000,) for ImageNet-1K)"
            self.head = nn.Linear(dim, self.output_shape[0])
            # near-zero init of the classifier head
            nn.init.trunc_normal_(self.head.weight, std=2e-5)
            nn.init.zeros_(self.head.bias)
        else:
            raise NotImplementedError

    def load_state_dict(self, state_dict, strict=True):
        # interpolate the positional embedding if the checkpoint resolution differs
        old_pos_embed = state_dict["pos_embed.embed"]
        if old_pos_embed.shape != self.pos_embed.embed.shape:
            state_dict["pos_embed.embed"] = interpolate_sincos(embed=old_pos_embed, seqlens=self.pos_embed.seqlens)
        return super().load_state_dict(state_dict=state_dict, strict=strict)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed.embed"}

    def forward(self, x):
        # embed patches and add positional embedding
        x = self.patch_embed(x)
        x = self.pos_embed(x)

        # flatten the spatial grid to a 1d token sequence
        x = einops.rearrange(x, "b ... d -> b (...) d")

        # apply the ViL blocks
        for block in self.blocks:
            x = block(x)
        x = self.legacy_norm(x)

        # pooling
        if self.pooling is None:
            x = self.norm(x)
        elif self.pooling == "bilateral_avg":
            # average of the first and last token
            x = (x[:, 0] + x[:, -1]) / 2
            x = self.norm(x)
        else:
            raise NotImplementedError(f"pooling '{self.pooling}' is not implemented")

        # classification head
        if self.head is not None:
            x = self.head(x)

        return x
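
# Usage sketch (hedged: assumes the VitPatchEmbed / VitPosEmbed2d utilities imported above behave
# as their names suggest; the hyperparameters below are illustrative, not the defaults used in any
# particular experiment):
#   model = VisionLSTM(dim=192, depth=12, patch_size=16,
#                      input_shape=(3, 224, 224), output_shape=(1000,))
#   logits = model(torch.randn(1, 3, 224, 224))   # -> (1, 1000) class logits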