| """Encoder self-attention layer definition.""" | |

import math
import pdb
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from vita.model.multimodal_encoder.whale.utils import IGNORE_ID, strtobool

try:
    from mamba_ssm.modules.mamba_simple import Mamba, Block
    from mamba_ssm.models.mixer_seq_simple import _init_weights
    from mamba_ssm.ops.triton.layernorm import RMSNorm
except ImportError:
    print("Please install mamba_ssm to use MambaSSM component.")


class MambaBlock(nn.Module):
    def __init__(self, in_channels, n_layer=1, d_state=16, d_conv=4, expand=4, bidirectional=False):
        super(MambaBlock, self).__init__()
        self.forward_blocks = nn.ModuleList([])
        self.forward_norm_f = RMSNorm(in_channels, eps=1e-5)
        for i in range(n_layer):
            self.forward_blocks.append(
                Block(
                    in_channels,
                    mixer_cls=partial(
                        Mamba, layer_idx=i, d_state=d_state, d_conv=d_conv, expand=expand
                    ),
                    norm_cls=partial(RMSNorm, eps=1e-5),
                    fused_add_norm=True,
                    residual_in_fp32=True,
                )
            )
        if bidirectional:
            self.backward_blocks = nn.ModuleList([])
            for i in range(n_layer):
                self.backward_blocks.append(
                    Block(
                        in_channels,
                        mixer_cls=partial(
                            Mamba, layer_idx=i, d_state=d_state, d_conv=d_conv, expand=expand
                        ),
                        norm_cls=partial(RMSNorm, eps=1e-5),
                        fused_add_norm=True,
                        residual_in_fp32=True,
                    )
                )
            self.backward_norm_f = RMSNorm(in_channels, eps=1e-5)
        else:
            self.backward_blocks = None
        self.apply(partial(_init_weights, n_layer=n_layer))

    def forward(self, input):
        # Run the forward-direction Mamba blocks; each Block returns
        # (hidden_states, residual), which are summed and normalized at the end.
        for_residual = None
        forward_f = input.clone()
        for block in self.forward_blocks:
            forward_f, for_residual = block(forward_f, for_residual, inference_params=None)
        residual = (forward_f + for_residual) if for_residual is not None else forward_f
        residual = self.forward_norm_f(residual)

        if self.backward_blocks is not None:
            # Bidirectional case: process the time-reversed input, flip the result
            # back, and concatenate with the forward output along the feature dim.
            back_residual = None
            backward_f = torch.flip(input, [1])
            for block in self.backward_blocks:
                backward_f, back_residual = block(backward_f, back_residual, inference_params=None)
            back_residual = (
                (backward_f + back_residual) if back_residual is not None else backward_f
            )
            back_residual = torch.flip(back_residual, [1])
            back_residual = self.backward_norm_f(back_residual)
            residual = torch.cat([residual, back_residual], -1)

        return residual
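
# A minimal usage sketch (an illustration, not part of the original file): the Mamba
# kernels in mamba_ssm are CUDA-only, so this assumes a CUDA device is available.
# With bidirectional=True the forward and backward outputs are concatenated, so the
# feature dimension doubles from in_channels to 2 * in_channels.
#
#     block = MambaBlock(in_channels=256, n_layer=2, bidirectional=True).cuda()
#     x = torch.randn(4, 100, 256, device="cuda")
#     y = block(x)  # (4, 100, 512); with bidirectional=False it stays (4, 100, 256)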


class MambaSSM(torch.nn.Module):
    @staticmethod
    def add_arguments(group):
        """Add MambaSSM common arguments."""
        group.add_argument(
            "--mamba-num-layers", default=4, type=int, help="Number of layers of MambaSSM."
        )
        group.add_argument(
            "--mamba-input-dim", default=256, type=int, help="Input dim of MambaSSM."
        )
        group.add_argument(
            "--mamba-output-dim", default=256, type=int, help="Output dim of MambaSSM."
        )
        group.add_argument("--mamba-d-state", default=16, type=int, help="d-state of MambaSSM.")
        group.add_argument("--mamba-d-conv", default=4, type=int, help="d-conv of MambaSSM.")
        group.add_argument("--mamba-expand", default=4, type=int, help="expand of MambaSSM.")
        return group

    def __init__(self, args):
        """Construct a MambaSSM encoder object."""
        super(MambaSSM, self).__init__()
        self.mamba_num_layers = args.mamba_num_layers
        self.mamba_input_dim = args.mamba_input_dim
        self.mamba_output_dim = args.mamba_output_dim
        self.mamba_d_state = args.mamba_d_state
        self.mamba_d_conv = args.mamba_d_conv
        self.mamba_expand = args.mamba_expand
        self.mamba = MambaBlock(
            in_channels=self.mamba_input_dim,
            n_layer=self.mamba_num_layers,
            d_state=self.mamba_d_state,
            d_conv=self.mamba_d_conv,
            expand=self.mamba_expand,
        )

    def forward(self, xs, ilens=None, masks=None):
        """Apply the Mamba SSM encoder to an input sequence.

        :param torch.Tensor xs: input tensor (batch, time, mamba_input_dim)
        :param torch.Tensor ilens: input lengths (passed through unchanged)
        :param torch.Tensor masks: input mask (passed through unchanged)
        :return: encoded tensor, input lengths, and mask
        :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        """
        xs_out = self.mamba(xs)
        return xs_out.to(xs.dtype), ilens, masks
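

# A minimal smoke-test sketch (an assumption for illustration, not part of the original
# module): it builds the encoder from the default CLI arguments declared above and assumes
# mamba_ssm is installed and a CUDA device is available, since the Mamba kernels are CUDA-only.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    MambaSSM.add_arguments(parser)
    args = parser.parse_args([])  # keep the defaults declared in add_arguments

    encoder = MambaSSM(args).cuda()
    xs = torch.randn(2, 50, args.mamba_input_dim, device="cuda")
    xs_out, ilens, masks = encoder(xs)
    print(xs_out.shape)  # expected: torch.Size([2, 50, 256]) with the default settings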