Spaces:
Runtime error
Runtime error
| import argparse | |
| import logging | |
| import sys | |
| import time | |
| from typing import Dict, Optional, Tuple | |
| import numpy as np | |
| import six | |
| import torch | |
| from vita.model.multimodal_encoder.whale.module.component.mamba import MambaSSM | |
| from vita.model.multimodal_encoder.whale.module.component.subsampling import Subsampling | |
| from vita.model.multimodal_encoder.whale.module.component.transformer import Transformer | |
| from vita.model.multimodal_encoder.whale.utils import make_pad_mask | |
| def add_encoder_args(group): | |
| """Add Encoder common arguments.""" | |
| group.add_argument( | |
| "--encoder-layer-config", | |
| type=str, | |
| default="tdnn-dtc", | |
| help="Layer config of encoder. Format layername-layername-..., default(conv1d-fsmn-rnn)", | |
| ) | |
| group.add_argument( | |
| "--encoder-input-dim", | |
| type=int, | |
| default=256, | |
| help="Input dim of encoder. Must equal to the input dim of the first Component (default=40)", | |
| ) | |
| group.add_argument( | |
| "--encoder-output-dim", | |
| type=int, | |
| default=256, | |
| help="Output dim of encoder. Must enqual to the output dim of the last Component ! (default=256)", | |
| ) | |
| # Add args of all kinds of components. | |
| # If you add a new component, DO NOT forget to add args to add_component_args func. | |
| group = Transformer.add_arguments(group) | |
| group = Subsampling.add_arguments(group) | |
| group = MambaSSM.add_arguments(group) | |
| return group | |
| def assign_args_from_dict(args, dict, prefix_key=None): | |
| if prefix_key is not None: | |
| dict = dict[prefix_key] | |
| for k, v in dict.items(): | |
| k_args = k.replace("-", "_") | |
| if hasattr(args, k_args): | |
| setattr(args, k_args, dict[k]) | |
| return args | |
| class whaleEncoder(torch.nn.Module): | |
| def __init__(self, input_dim, overview_conf=None, para_conf=None, global_cmvn=None): | |
| super(whaleEncoder, self).__init__() | |
| parser = argparse.ArgumentParser() | |
| add_encoder_args(parser) | |
| args, _ = parser.parse_known_args() | |
| assign_args_from_dict(args, overview_conf) | |
| # assign_args_from_dict(args, para_conf) | |
| self.config = args.encoder_layer_config.split("-") | |
| encoder_input_dim = args.encoder_input_dim | |
| encoder_output_dim = args.encoder_output_dim | |
| prev_output_dim = encoder_input_dim | |
| prev_component_name = "encoder" | |
| self.enc = torch.nn.ModuleList([]) | |
| for name in self.config: | |
| assign_args_from_dict(args, para_conf[name]) | |
| if len(name.split("_")) == 2: | |
| name = name.split("_")[0] | |
| elif len(name.split("_")) == 1: | |
| name = name | |
| else: | |
| logging.error("WRONG CONFIG! {} is not valid".format("encoder", name)) | |
| sys.exit() | |
| if name == "transformer": | |
| self.enc.append(Transformer(args)) | |
| elif name == "subsampling": | |
| self.enc.append(Subsampling(args)) | |
| elif name == "mamba": | |
| self.enc.append(MambaSSM(args)) | |
| else: | |
| print("{} is not supported now!".format(name)) | |
| return NotImplemented | |
| component_input_dim = getattr(args, name + "_input_dim") | |
| if component_input_dim != prev_output_dim: | |
| # This is the first layer | |
| logging.error( | |
| "WRONG CONFIG! --{}-output-dim ({}) does not equal to --{}-input-dim ({})".format( | |
| prev_component_name, prev_output_dim, name, component_input_dim | |
| ) | |
| ) | |
| sys.exit() | |
| prev_output_dim = getattr(args, name + "_output_dim") | |
| prev_component_name = name | |
| self.global_cmvn = global_cmvn | |
| if prev_output_dim != encoder_output_dim: | |
| logging.error( | |
| "WRONG CONFIG! --{}-output-dim ({}) does not equal to --{}-output-dim ({}, the last component)".format( | |
| "encoder", encoder_output_dim, name, prev_output_dim | |
| ) | |
| ) | |
| sys.exit() | |
| self._output_size = encoder_output_dim | |
| num_params = sum(p.numel() for p in self.parameters()) | |
| print("the number of whale encoder params: {}M".format(num_params / 1024 / 1024)) | |
| def output_size(self) -> int: | |
| return self._output_size | |
| def forward(self, xs, ilens, decoding_chunk_size=None, num_decoding_left_chunks=None): | |
| # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Optional[List[int]], Optional[Tensor]] | |
| """Encoder forward | |
| :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D) | |
| :param torch.Tensor ilens: batch of lengths of input sequences (B) | |
| :return: batch of hidden state sequences (B, Tmax, eprojs) | |
| :rtype: torch.Tensor | |
| """ | |
| if decoding_chunk_size is not None and num_decoding_left_chunks is not None: | |
| for layer in self.enc: | |
| if hasattr(layer, "chunk_size"): | |
| layer.chunk_size = decoding_chunk_size | |
| if hasattr(layer, "left_chunks"): | |
| layer.left_chunks = num_decoding_left_chunks | |
| if hasattr(layer, "transformer_dynamic_chunks"): | |
| layer.transformer_dynamic_chunks = False | |
| assert (len(xs.shape)) == 3 | |
| T = xs.size(1) | |
| masks = ~make_pad_mask(ilens, T).unsqueeze(1) # (B, 1, T) | |
| if self.global_cmvn is not None: | |
| xs = self.global_cmvn(xs) | |
| for module in self.enc: | |
| xs, ilens, masks = module(xs, ilens, masks) | |
| return xs, masks | |
| def infer(self, xs_pad, buffer, buffer_index, buffer_out): | |
| if self.global_cmvn is not None: | |
| xs = self.global_cmvn(xs) | |
| for module in self.enc: | |
| xs_pad, buffer, buffer_index, buffer_out = module.infer( | |
| xs_pad, buffer, buffer_index, buffer_out | |
| ) | |
| return xs_pad, buffer, buffer_index, buffer_out | |
| def infer_hidden(self, xs_pad, buffer, buffer_index, buffer_out, hidden_out): | |
| if self.global_cmvn is not None: | |
| xs = self.global_cmvn(xs) | |
| for module in self.enc: | |
| xs_pad, buffer, buffer_index, buffer_out, hidden_out = module.infer_hidden( | |
| xs_pad, buffer, buffer_index, buffer_out, hidden_out | |
| ) | |
| return xs_pad, buffer, buffer_index, buffer_out, hidden_out | |
| def get_extra_loss(self) -> Dict[str, torch.Tensor]: | |
| return None | |