import argparse
import math
from typing import List, Tuple

from model import MultiKDModel
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from zipformer import Zipformer2

import torchaudio
from torchaudio.compliance.kaldi import fbank
import torch
from torch import Tensor
import torch.nn as nn

from utilities import make_pad_mask, str2bool, ZipformerConfig

LOG_EPS = math.log(1e-10)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model-version",
        type=str,
        default="600m_uniform_out_ds1",
        help="Which model configuration to build; see get_params() for supported values.",
    )

    parser.add_argument(
        "--causal",
        type=str2bool,
        default=False,
        help="If True, use the causal version of the model.",
    )

    parser.add_argument(
        "--chunk-size",
        type=str,
        default="16,32,64,-1",
        help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
        "Must be just -1 if --causal=False.",
    )

    parser.add_argument(
        "--left-context-frames",
        type=str,
        default="64,128,256,-1",
        help="Maximum left-contexts for causal training, measured in frames which will "
        "be converted to a number of chunks. If splitting into chunks, "
        "chunk left-context frames will be chosen randomly from this list; else not relevant.",
    )

    parser.add_argument(
        "--ckpt-path",
        type=str,
        required=True,
        help="Path to the model checkpoint to load.",
    )

    parser.add_argument(
        "--audio",
        type=str,
        required=True,
        help="Path to the input audio file (expected to be 16 kHz).",
    )

    return parser


def _to_int_tuple(s: str):
    return tuple(map(int, s.split(",")))


def get_encoder_embed(params) -> nn.Module:
    encoder_embed = Conv2dSubsampling(
        in_channels=128,
        out_channels=_to_int_tuple(params.encoder_dim)[0],
        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
    )
    return encoder_embed


def get_encoder_model(params) -> nn.Module:
    encoder = Zipformer2(
        output_downsampling_factor=params.output_downsampling_factor,
        downsampling_factor=_to_int_tuple(params.downsampling_factor),
        num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
        encoder_dim=_to_int_tuple(params.encoder_dim),
        encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
        query_head_dim=_to_int_tuple("32"),
        pos_head_dim=_to_int_tuple("4"),
        value_head_dim=_to_int_tuple("12"),
        pos_dim=params.pos_dim,
        num_heads=_to_int_tuple(params.num_heads),
        feedforward_dim=_to_int_tuple(params.feedforward_dim),
        cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
        warmup_batches=4000.0,
        causal=params.causal,
        chunk_size=_to_int_tuple(params.chunk_size),
        left_context_frames=_to_int_tuple(params.left_context_frames),
    )
    return encoder


def get_params(args):
    params = ZipformerConfig()
    # Propagate the streaming-related options from the command line; they are
    # consumed by get_encoder_model().
    params.causal = args.causal
    params.chunk_size = args.chunk_size
    params.left_context_frames = args.left_context_frames

    model_version = args.model_version
    if model_version == "600m_uniform_out_ds1":
        params.output_downsampling_factor = 1
        params.downsampling_factor = "1,2,4,8,4,2,1"
        params.num_encoder_layers = "1,2,3,4,1,1,1"
        params.feedforward_dim = "3840,3840,3840,3840,3840,3840,3840"
        params.encoder_dim = "1280,1280,1280,1280,1280,1280,1280"
        params.encoder_unmasked_dim = "768,768,768,768,768,768,768"
        params.cnn_module_kernel = "31,31,15,15,15,31,31"
        params.num_heads = "8,8,8,8,8,8,8"
    elif model_version == "600m_uniform_out_ds2":
        params.output_downsampling_factor = 2
        params.downsampling_factor = "1,2,4,8,4,2,1"
        params.num_encoder_layers = "1,2,3,4,1,1,1"
        params.feedforward_dim = "3840,3840,3840,3840,3840,3840,3840"
        params.encoder_dim = "1280,1280,1280,1280,1280,1280,1280"
        params.encoder_unmasked_dim = "768,768,768,768,768,768,768"
        params.cnn_module_kernel = "31,31,15,15,15,31,31"
        params.num_heads = "8,8,8,8,8,8,8"
    else:
        raise ValueError(f"Unknown model version: {model_version}")
    return params


def get_model(args) -> nn.Module:
    params = get_params(args)
    encoder_embed = get_encoder_embed(params)
    encoder = get_encoder_model(params)
    print(params)

    model = MultiKDModel(
        encoder_embed=encoder_embed,
        encoder=encoder,
        encoder_dim=max(_to_int_tuple(params.encoder_dim)),
        num_codebooks=0,
    )

    return model


def get_init_states(
    model: nn.Module,
    batch_size: int = 1,
    device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
    """
    Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
    is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
    states[-2] is the cached left padding for ConvNeXt module,
    of shape (batch_size, num_channels, left_pad, num_freqs).
    states[-1] is processed_lens of shape (batch,), which records the number
    of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
    """
    states = model.encoder.get_init_states(batch_size, device)

    embed_states = model.encoder_embed.get_init_states(batch_size, device)
    states.append(embed_states)

    processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
    states.append(processed_lens)

    return states
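# Note: chunk_forward() below keeps the states of a single utterance, wraps them with
# stack_states([states]) before each call to streaming_forward(), and unpacks the
# updated batch state again with unstack_states(new_states)[0].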


def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
    """Stack a list of zipformer states that correspond to separate utterances
    into a single zipformer state, so that it can be used as an input for
    zipformer when those utterances are formed into a batch.

    Args:
      state_list:
        Each element in state_list corresponds to the internal state
        of the zipformer model for a single utterance. For element-n,
        state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
        state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
        cached_val2, cached_conv1, cached_conv2).
        state_list[n][-2] is the cached left padding for ConvNeXt module,
        of shape (batch_size, num_channels, left_pad, num_freqs).
        state_list[n][-1] is processed_lens of shape (batch,), which records the number
        of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.

    Note:
      It is the inverse of :func:`unstack_states`.
    """
    batch_size = len(state_list)
    assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
    tot_num_layers = (len(state_list[0]) - 2) // 6

    batch_states = []
    for layer in range(tot_num_layers):
        layer_offset = layer * 6

        # The attention-related caches are batched along dim=1, the convolution
        # caches along dim=0.
        cached_key = torch.cat(
            [state_list[i][layer_offset] for i in range(batch_size)], dim=1
        )
        cached_nonlin_attn = torch.cat(
            [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
        )
        cached_val1 = torch.cat(
            [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
        )
        cached_val2 = torch.cat(
            [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
        )
        cached_conv1 = torch.cat(
            [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
        )
        cached_conv2 = torch.cat(
            [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
        )
        batch_states += [
            cached_key,
            cached_nonlin_attn,
            cached_val1,
            cached_val2,
            cached_conv1,
            cached_conv2,
        ]

    cached_embed_left_pad = torch.cat(
        [state_list[i][-2] for i in range(batch_size)], dim=0
    )
    batch_states.append(cached_embed_left_pad)

    processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
    batch_states.append(processed_lens)

    return batch_states


def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
    """Unstack the zipformer state corresponding to a batch of utterances
    into a list of states, where the i-th entry is the state from the i-th
    utterance in the batch.

    Note:
      It is the inverse of :func:`stack_states`.

    Args:
      batch_states: A list of cached tensors of all encoder layers. For layer-i,
        batch_states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
        cached_conv1, cached_conv2).
        batch_states[-2] is the cached left padding for ConvNeXt module,
        of shape (batch_size, num_channels, left_pad, num_freqs).
        batch_states[-1] is processed_lens of shape (batch,), which records the number
        of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.

    Returns:
      state_list: A list of lists. Each element in state_list corresponds to the internal state
        of the zipformer model for a single utterance.
    """
    assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
    tot_num_layers = (len(batch_states) - 2) // 6

    processed_lens = batch_states[-1]
    batch_size = processed_lens.shape[0]

    state_list = [[] for _ in range(batch_size)]

    for layer in range(tot_num_layers):
        layer_offset = layer * 6

        cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
        cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
            chunks=batch_size, dim=1
        )
        cached_val1_list = batch_states[layer_offset + 2].chunk(
            chunks=batch_size, dim=1
        )
        cached_val2_list = batch_states[layer_offset + 3].chunk(
            chunks=batch_size, dim=1
        )
        cached_conv1_list = batch_states[layer_offset + 4].chunk(
            chunks=batch_size, dim=0
        )
        cached_conv2_list = batch_states[layer_offset + 5].chunk(
            chunks=batch_size, dim=0
        )
        for i in range(batch_size):
            state_list[i] += [
                cached_key_list[i],
                cached_nonlin_attn_list[i],
                cached_val1_list[i],
                cached_val2_list[i],
                cached_conv1_list[i],
                cached_conv2_list[i],
            ]

    cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
    for i in range(batch_size):
        state_list[i].append(cached_embed_left_pad_list[i])

    processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
    for i in range(batch_size):
        state_list[i].append(processed_lens_list[i])

    return state_list
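# Illustrative round-trip check (assumes `model` has already been constructed):
#   states = get_init_states(model, batch_size=1)
#   restored = unstack_states(stack_states([states]))[0]
#   assert len(restored) == len(states)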


def compute_fbank(
    wavs: torch.Tensor, wav_lens: torch.Tensor
):
    """Compute fbank features.

    Args:
      wavs (torch.Tensor): the mono-channel input waveforms, (N, T)
      wav_lens (torch.Tensor): the length of each waveform in samples, (N,)

    Returns:
      The fbank features and their lengths.
    """
    assert wavs.ndim == 2, wavs.shape
    low_freq = 20.0
    high_freq = -400.0  # <= 0 means an offset from the Nyquist frequency
    dither = 0.0
    snip_edges = False

    features = []
    for i, wav in enumerate(wavs):
        feat = fbank(
            wav[:wav_lens[i]].unsqueeze(0),
            sample_frequency=16000,
            num_mel_bins=128,
            low_freq=low_freq,
            snip_edges=snip_edges,
            high_freq=high_freq,
            dither=dither,
            energy_floor=1.0e-10,
        )
        features.append(feat)
    feat_len = torch.tensor([f.shape[0] for f in features]).to(wavs.device)
    features = torch.nn.utils.rnn.pad_sequence(
        features, batch_first=True, padding_value=LOG_EPS
    ).to(wavs.device)
    return features, feat_len
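# For example, one second of 16 kHz audio (wavs of shape (1, 16000), wav_lens = [16000])
# yields features of shape roughly (1, ~100, 128): the 10 ms frame shift with
# snip_edges=False gives about 100 frames per second.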


def streaming_forward(
    features: Tensor,
    feature_lens: Tensor,
    model: nn.Module,
    states: List[Tensor],
    chunk_size: int,
    left_context_len: int,
) -> Tuple[Tensor, Tensor, List[Tensor], List[Tensor]]:
    """
    Returns encoder outputs, output lengths, updated states, and the intermediate
    (middle-layer) outputs.
    """
    cached_embed_left_pad = states[-2]
    (x, x_lens, new_cached_embed_left_pad) = model.encoder_embed.streaming_forward(
        x=features,
        x_lens=feature_lens,
        cached_left_pad=cached_embed_left_pad,
    )
    assert x.size(1) == chunk_size, (x.size(1), chunk_size)

    src_key_padding_mask = make_pad_mask(x_lens)

    # Build the padding mask for the cached left context: positions that have not
    # been processed yet are treated as padding.
    processed_mask = torch.arange(left_context_len, device=x.device).expand(
        x.size(0), left_context_len
    )
    processed_lens = states[-1]
    # (batch,) -> (batch, left_context_len); True means padding.
    processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)

    new_processed_lens = processed_lens + x_lens

    # (batch, left_context_len + chunk_size)
    src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)

    x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
    encoder_states = states[:-2]
    (
        encoder_out,
        encoder_out_lens,
        new_encoder_states,
        middle_outs,
    ) = model.encoder.streaming_forward(
        x=x,
        x_lens=x_lens,
        states=encoder_states,
        src_key_padding_mask=src_key_padding_mask,
    )
    encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
    middle_outs = [m.permute(1, 0, 2) for m in middle_outs]

    new_states = new_encoder_states + [
        new_cached_embed_left_pad,
        new_processed_lens,
    ]
    return encoder_out, encoder_out_lens, new_states, middle_outs


def chunk_forward(
    audio: torch.Tensor,
    model: torch.nn.Module,
    feature_dim: int = 128,
    chunk_size: int = 8,
    left_context_frames: int = 256,
):
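    """Run streaming (chunk-by-chunk) encoding of `audio` with the given model.

    Args:
      audio: (1, num_samples) waveform at 16 kHz.
      model: the MultiKDModel whose encoder supports streaming_forward().
      feature_dim: number of feature bins (kept for reference; compute_fbank
        currently hard-codes 128 mel bins).
      chunk_size: chunk size in encoder frames (50 Hz, i.e. after the 2x
        subsampling of encoder_embed).
      left_context_frames: number of left-context frames kept by the encoder.

    Returns:
      The concatenated encoder outputs, their lengths, and a list of concatenated
      intermediate (middle-layer) outputs.
    """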
    device = next(model.parameters()).device

    chunk_size = int(chunk_size)
    # chunk_size is measured at the 50Hz encoder frame rate; the fbank frame shift
    # is 10 ms (160 samples at 16 kHz), so one encoder frame covers 2 * 160 samples.
    chunk_size_samples = int(chunk_size * 2 * 160)
    left_context_len = int(left_context_frames)

    # Extra feature frames consumed by the convolutional front-end (encoder_embed).
    pad_length = 7 + 2 * 3
    pad_length_samples = pad_length * 160

    # Read slightly more audio than strictly needed so that enough fbank frames
    # are produced for each chunk.
    extra_tolerance = 0.01
    extra_tolerance_samples = int(extra_tolerance * 16000)
    buffer_samples = pad_length_samples + extra_tolerance_samples

    chunk_size_with_pad = chunk_size * 2 + pad_length

    initial_states = get_init_states(model=model, batch_size=1, device=device)
    encoder_outs = []
    middle_outs = []
    encoder_out_lens = 0
    states = initial_states

    num_chunk = 0
    num_processed_samples = 0

    while True:
        audio_chunk = audio[
            :,
            num_processed_samples: num_processed_samples + (chunk_size_samples + buffer_samples)
        ]

        features, _ = compute_fbank(audio_chunk, torch.tensor([audio_chunk.shape[-1]]))
        features = features[:, :chunk_size_with_pad, :]
        features = features.to(device)
        feature_lens = torch.tensor([features.shape[1]], device=device)

        # Pad the last (possibly shorter) chunk up to the full chunk length.
        if features.size(1) < chunk_size_with_pad:
            tail_pad_length = chunk_size_with_pad - features.size(1)
            feature_lens += tail_pad_length
            features = torch.nn.functional.pad(
                features,
                (0, 0, 0, tail_pad_length),
                mode="constant",
                value=LOG_EPS,
            )

        states = stack_states([states])

        encoder_out, encoder_out_len, new_states, middle_out = streaming_forward(
            features=features,
            feature_lens=feature_lens,
            model=model,
            states=states,
            chunk_size=chunk_size,
            left_context_len=left_context_len,
        )

        encoder_outs.append(encoder_out)
        middle_outs.append(middle_out)
        encoder_out_lens += encoder_out_len

        states = unstack_states(new_states)[0]

        num_chunk += 1
        num_processed_samples += chunk_size_samples

        if num_processed_samples >= audio.shape[1]:
            print("Audio is exhausted.")
            break

    encoder_outs = torch.cat(encoder_outs, dim=1)
    # Concatenate the per-chunk intermediate outputs layer by layer.
    layerwise_outs = []
    for i in range(len(middle_outs[0])):
        layerwise_outs.append(torch.cat([m[i] for m in middle_outs], dim=1))

    return encoder_outs, encoder_out_lens, layerwise_outs


def main(args):
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda")

    model = get_model(args)
    model.to(device)

    info = model.load_state_dict(
        torch.load(args.ckpt_path, map_location="cpu")["model"], strict=False
    )
    print(info)
    model.eval()

    audio, fs = torchaudio.load(args.audio)
    assert fs == 16000, fs

    # --chunk-size and --left-context-frames are comma-separated lists (a training
    # convention); for streaming inference we use the first value in each list.
    chunk_size = _to_int_tuple(args.chunk_size)[0]
    left_context_frames = _to_int_tuple(args.left_context_frames)[0]

    encoder_out, encoder_out_lens, intermediate_hidden_states = chunk_forward(
        audio=audio,
        model=model,
        feature_dim=128,
        chunk_size=chunk_size,
        left_context_frames=left_context_frames,
    )

    print(encoder_out)
    print(encoder_out.shape)
    print(intermediate_hidden_states[-1])
    print(intermediate_hidden_states[-1].shape)


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()

    main(args)
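# Example invocation (the script name and paths are illustrative):
#   python streaming_forward.py \
#     --model-version 600m_uniform_out_ds1 \
#     --causal True \
#     --chunk-size 16 \
#     --left-context-frames 128 \
#     --ckpt-path /path/to/checkpoint.pt \
#     --audio /path/to/audio.wav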