| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import logging |
| import math |
| from typing import Optional, Tuple |
| import random |
|
|
| import numpy as np |
| import torch |
| import torchaudio |
| import torch.nn as nn |
| from torch.nn.utils.rnn import pad_sequence |
| from torchaudio.compliance.kaldi import fbank as torch_fbank |
|
|
| from .configuration_spear import SpearConfig |
| from .zipformer import Zipformer2, Conv2dSubsampling |
|
|
# Log-domain floor, log(1e-10); used as the padding value when batching
# variable-length fbank feature matrices.
LOG_EPS = math.log(1e-10)
# Sample rate (Hz) that every loaded waveform is resampled to.
SAMPLING_RATE = 16000
|
|
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Build a boolean padding mask from a vector of sequence lengths.

    Args:
      lengths:
        A 1-D tensor containing sequence lengths.
      max_len:
        Minimum width of the mask. The actual width is
        ``max(max_len, lengths.max())`` so that every sequence fits.

    Returns:
      A 2-D bool tensor of shape ``(len(lengths), width)``. Padded positions
      (index >= length) are ``True`` and valid positions are ``False``. An
      empty ``lengths`` yields a ``(0, max_len)`` mask instead of raising.

    This function is borrowed from https://github.com/k2-fsa/icefall

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False,  True,  True,  True,  True],
            [False, False, False,  True,  True],
            [False, False,  True,  True,  True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    # ``lengths.max()`` raises on an empty tensor, so only consult it when there
    # is at least one sequence. Casting to int keeps ``max_len`` a plain Python
    # integer instead of sometimes being a 0-d tensor.
    if lengths.numel() > 0:
        max_len = max(int(max_len), int(lengths.max()))
    positions = torch.arange(max_len, device=lengths.device)
    # (1, max_len) compared against (n, 1) broadcasts to the (n, max_len) mask.
    return positions.unsqueeze(0) >= lengths.unsqueeze(-1)
|
|
def get_model(config: SpearConfig) -> nn.Module:
    """Assemble a :class:`SpearEncoder` (front-end + Zipformer2) from ``config``.

    ``num_codebooks=0`` is passed, so no codebook-loss head is created and the
    model is usable as a feature encoder only. ``encoder_dim`` is set to the
    largest entry of the comma-separated ``config.encoder_dim``.
    """
    return SpearEncoder(
        encoder_embed=get_encoder_embed(config),
        encoder=get_encoder_model(config),
        encoder_dim=max(_to_int_tuple(config.encoder_dim)),
        num_codebooks=0,
    )
|
|
class SpearModel(nn.Module):
    """Inference wrapper around the SPEAR encoder.

    Provides audio loading (``load_audio``), Kaldi-compatible log-Mel fbank
    extraction (``compute_fbank``) and encoding (``forward``).
    """

    def __init__(
        self, config: SpearConfig,
    ):
        super().__init__()
        model = get_model(config)
        self.config = config
        self.model = model

    def _load_audio_single(self, audio_path: str) -> Tuple[torch.Tensor, int]:
        """Load one audio file as a mono waveform sampled at ``SAMPLING_RATE``.

        Multi-channel audio is down-mixed by averaging the channels. Audio at
        any other sample rate is resampled.

        Returns:
            The waveform of shape (1, num_samples) and its length in samples.
        """
        waveform, sr = torchaudio.load(audio_path)
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != SAMPLING_RATE:
            transform = torchaudio.transforms.Resample(sr, SAMPLING_RATE)
            waveform = transform(waveform)
        waveform_len = waveform.shape[-1]
        return waveform, waveform_len

    def load_audio(self, audio_paths: list[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load a batch of audio files and zero-pad them to a common length.

        Args:
            audio_paths: list of paths to the audio files.

        Returns:
            A (N, max_num_samples) tensor of zero-padded mono waveforms, and a
            (N,) tensor with the length of each waveform in samples.
        """
        assert isinstance(audio_paths, list), "Must receive a list of files for reading"
        waveforms = []
        waveform_lens = []
        for audio in audio_paths:
            wav, wav_len = self._load_audio_single(audio)
            # Drop only the channel axis. A bare squeeze() would also collapse a
            # one-sample file into a 0-d tensor that pad_sequence cannot handle.
            waveforms.append(wav.squeeze(0))
            waveform_lens.append(wav_len)

        waveforms = pad_sequence(waveforms, batch_first=True)
        waveform_lens = torch.tensor(waveform_lens)
        return waveforms, waveform_lens

    def compute_fbank(
        self, wavs: torch.Tensor, wav_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute log-Mel filterbank features for a padded batch of waveforms.

        Args:
            wavs (torch.Tensor): the mono-channel input waveforms, (N, T)
            wav_lens (torch.Tensor): the length of each waveform in samples (N)

        Returns:
            The fbank features (N, num_frames, num_mel_bins), padded with
            ``LOG_EPS``, and the number of frames of each utterance (N).
        """
        assert wavs.ndim == 2, wavs.shape
        # The Mel-bin count must match the encoder front-end, which is built
        # with in_channels=config.num_mel_bins.
        num_mel_bins = self.config.num_mel_bins

        features = []
        for i, wav in enumerate(wavs):
            feat = torch_fbank(
                # Only the valid (unpadded) samples of utterance i.
                wav[:wav_lens[i]].unsqueeze(0),
                sample_frequency=SAMPLING_RATE,
                num_mel_bins=num_mel_bins,
                low_freq=20.0,
                snip_edges=False,
                # Kaldi convention: a non-positive high_freq is an offset
                # from the Nyquist frequency.
                high_freq=-400.0,
                dither=0.0,
                energy_floor=1.0e-10,
            )
            features.append(feat)
        feat_len = torch.tensor([f.shape[0] for f in features]).to(wavs.device)
        features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS).to(wavs.device)
        return features, feat_len

    def forward(self, audio: torch.Tensor, audio_lens: torch.Tensor, return_middle_layers: bool = True):
        """Encode a batch of audio.

        Args:
            audio (torch.Tensor): Input audio waveforms (N, L)
            audio_lens (torch.Tensor): The length of the audio waveforms (N)
            return_middle_layers (bool, optional): Whether the intermediate
                encoder layer outputs are requested; forwarded to
                ``forward_encoder`` as ``return_middle_out``.

        Returns:
            A dict with ``encoder_out`` (N, T, C), ``encoder_out_lens`` (N)
            and ``hidden_states`` (the intermediate layer outputs, as returned
            by ``forward_encoder``).
        """
        x, x_lens = self.compute_fbank(audio, audio_lens)
        outputs = self.model.forward_encoder(
            x=x,
            x_lens=x_lens,
            return_middle_out=return_middle_layers,
            return_dict=True,
        )
        return outputs
| |
|
|
class SpearEncoder(nn.Module):
    def __init__(
        self,
        encoder_embed: nn.Module,
        encoder: nn.Module,
        encoder_dim: int,
        num_codebooks: int = 8,
        distillation_layer: int = 9,
        distillation_delta: int = 0,
        teacher_frame_ratio: int = 2,
        interpolate_teacher: bool = False,
        n_mels: int = 128,
        mask_mode: str = "w2v2",
        mask_prob: float = 0.65,
        mask_length: int = 10,
        mask_selection: str = "static",
        mask_other: float = 0.0,
        min_masks: int = 2,
        mask_channel_prob: float = 0.0,
        mask_channel_length: int = 10,
        mask_channel_selection: str = "static",
        mask_channel_other: float = 0.0,
        loss_only_mask: bool = False,
    ):
        """A model that performs MVQ KD pre-training.

        Args:
          encoder_embed:
            Convolutional 2D subsampling front-end. It converts an input of
            shape (N, T, idim) into (N, T', odim) and also returns the new
            lengths.
          encoder:
            The encoder network. It is called as
            ``encoder(x, x_lens, src_key_padding_mask, return_middle_out=True)``
            with time-major ``x`` of shape (T', N, C). It must return
            ``(encoder_out, encoder_out_lens, middle_out)``, where
            ``encoder_out`` is time-major and ``middle_out`` is a list of
            time-major intermediate outputs.
          encoder_dim:
            Input dimension of the codebook loss network.
          num_codebooks:
            The number of codebooks used in the target. 0 disables the
            codebook loss network (``codebook_loss_net`` is None).
          distillation_layer:
            Stored on the module; not read by any method of this class.
          distillation_delta:
            How many frames to delay the alignment between the model and the
            target frames. Should be zero for non-streaming models, and a
            positive number for streaming models.
          teacher_frame_ratio:
            The frame rate ratio between the target and the model output.
          interpolate_teacher:
            If True, resample the teacher codebook indexes to the encoder frame
            rate (nearest-neighbour interpolation) instead of concatenating
            successive teacher frames.
          n_mels:
            Size of the learned mask embedding that replaces masked frames.
          mask_mode:
            The masking mode.
              w2v2: wav2vec2-style span masking (``compute_mask_indices``);
                    spans may overlap.
              block: fixed-length, non-overlapping segments
                     (``compute_mask_indices_block``).
          mask_prob:
            w2v2: probability of choosing a frame as a span start.
            block: probability of masking each segment.
          mask_length:
            The length of each mask span / segment.
          mask_selection:
            How to determine the length of the mask in w2v2 mode, see
            ``compute_mask_indices``.
          mask_other:
            Secondary parameter of ``mask_selection``.
          min_masks:
            Minimum number of masked spans / segments per utterance.
          mask_channel_prob, mask_channel_length, mask_channel_selection,
          mask_channel_other:
            Feature-dimension masking in w2v2 mode, disabled when
            ``mask_channel_prob`` is 0.
          loss_only_mask:
            If True, only frames that were masked in the input contribute to
            the codebook loss.
        """
        super().__init__()

        self.encoder_embed = encoder_embed
        self.encoder = encoder
        self.encoder_dim = encoder_dim

        self.distillation_layer = distillation_layer

        self.num_codebooks = num_codebooks
        self.teacher_frame_ratio = teacher_frame_ratio
        self.interpolate_teacher = interpolate_teacher
        self.distillation_delta = distillation_delta

        if num_codebooks > 0:
            # Imported lazily so inference-only use does not need spear_modules.
            from .spear_modules import JointCodebookLoss
            self.codebook_loss_net = JointCodebookLoss(
                input_dim=encoder_dim,
                num_codebooks=num_codebooks * self.teacher_frame_ratio,
                reduction="none",
            )
        else:
            self.codebook_loss_net = None

        assert mask_mode in ["w2v2", "block"], f"Unseen mask mode: {mask_mode}"
        self.mask_mode = mask_mode

        self.mask_emb = nn.Parameter(torch.FloatTensor(n_mels).normal_())
        self.mask_prob = mask_prob
        self.mask_length = mask_length
        self.mask_selection = mask_selection
        self.mask_other = mask_other
        self.min_masks = min_masks

        self.mask_channel_prob = mask_channel_prob
        self.mask_channel_length = mask_channel_length
        self.mask_channel_selection = mask_channel_selection
        self.mask_channel_other = mask_channel_other

        self.loss_only_mask = loss_only_mask

    def forward_encoder(
        self, x: torch.Tensor, x_lens: torch.Tensor, return_middle_out: bool = False, return_dict: bool = False,
    ):
        """Compute encoder outputs.

        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          return_middle_out:
            If True, also return the intermediate layer outputs, each permuted
            to batch-major layout. If False, ``None`` is returned in their place.
          return_dict:
            If True, return a dict instead of a tuple.

        Returns:
          If ``return_dict`` is False, the tuple
          ``(encoder_out, encoder_out_lens, middle_out)``. Otherwise a dict with
          keys ``encoder_out``, ``encoder_out_lens`` and ``hidden_states``.
          ``encoder_out`` has shape (N, T', C) and ``encoder_out_lens`` has
          shape (N,).
        """
        x, x_lens = self.encoder_embed(x, x_lens)

        src_key_padding_mask = make_pad_mask(x_lens)
        # The encoder consumes time-major input: (T, N, C).
        x = x.permute(1, 0, 2)

        # The encoder is always asked for its middle outputs; the flag below only
        # controls whether they are handed back to the caller.
        encoder_out, encoder_out_lens, middle_out = self.encoder(x, x_lens, src_key_padding_mask, return_middle_out=True)

        encoder_out = encoder_out.permute(1, 0, 2)
        assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)

        if return_middle_out:
            middle_out = [feat.permute(1, 0, 2) for feat in middle_out]
        else:
            middle_out = None

        if not return_dict:
            return encoder_out, encoder_out_lens, middle_out
        return {
            "encoder_out": encoder_out,
            "encoder_out_lens": encoder_out_lens,
            "hidden_states": middle_out,
        }

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        codebook_indexes: torch.Tensor = None,
        mask: bool = True,
    ) -> Optional[torch.Tensor]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          codebook_indexes:
            Codebook indexes of teacher embeddings (required).
          mask:
            Whether to apply masking (``mask_mode``) to the fbank frames. Only
            takes effect in training mode.

        Returns:
          The per-utterance codebook loss of shape (N,), summed over frames.
          Returns None if the model was built without a codebook loss network
          (``num_codebooks == 0``).
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert codebook_indexes is not None

        if self.training and mask:
            padding_mask = make_pad_mask(x_lens)
            # Clone so the caller's feature tensor is not modified in place.
            x, mask_indices = self.apply_mask(
                x.clone(),
                padding_mask=padding_mask
            )
        else:
            mask_indices = None

        encoder_out, encoder_out_lens, _ = self.forward_encoder(x, x_lens)

        if codebook_indexes is not None and self.codebook_loss_net is not None:
            codebook_loss = self.forward_codebook_loss(
                encoder_out, encoder_out_lens, codebook_indexes, reduction="none"
            )
            if self.loss_only_mask and mask_indices is not None:
                # Down-sample the frame-level mask by 4 so it lines up with the
                # encoder output; an output frame counts as masked when at least
                # half of its 4 input frames were masked. The factor of 4
                # presumably matches the total subsampling of encoder_embed and
                # encoder (TODO confirm).
                mask_indices = nn.functional.avg_pool1d(mask_indices, 4) >= 0.5
                assert mask_indices.size(1) >= codebook_loss.size(1)
                mask_indices = mask_indices[:, :codebook_loss.size(1)].float()
                codebook_loss = codebook_loss * mask_indices
            codebook_loss = codebook_loss.sum(dim=1)
        else:
            codebook_loss = None

        return codebook_loss

    def forward_codebook_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        codebook_indexes: torch.Tensor,
        reduction: str = "sum",
    ):
        """Compute the MVQ codebook-prediction loss.

        The teacher codebook indexes are first aligned to the encoder frame rate.
        With ``interpolate_teacher`` they are resampled by nearest-neighbour
        interpolation. Otherwise, when the frame counts differ, successive
        ``teacher_frame_ratio`` frames are concatenated.

        With ``distillation_delta > 0``, encoder frame ``t + delta`` predicts
        teacher frame ``t``. Target positions beyond each shortened length are
        set to -100.

        Args:
          encoder_out: (N, T, C) encoder output.
          encoder_out_lens: (N,) number of valid encoder frames.
          codebook_indexes: (N, T_teacher, num_codebooks) teacher indexes.
          reduction: "sum" gives an (N,) loss and "none" gives an (N, T) loss.
            Both are divided by the size of the last dim of the reshaped loss
            (presumably one entry per codebook).

        Raises:
          NotImplementedError: for any other ``reduction``.
        """
        if self.interpolate_teacher:
            codebook_indexes = self.interpolate_codebook_indexes(
                encoder_out, codebook_indexes
            )
        else:
            if codebook_indexes.shape[1] != encoder_out.shape[1]:
                codebook_indexes = self.concat_successive_codebook_indexes(
                    encoder_out, codebook_indexes, ratio=self.teacher_frame_ratio
                )

        if self.distillation_delta > 0:
            codebook_indexes = codebook_indexes[:, :-self.distillation_delta, :]
            encoder_out = encoder_out[:, self.distillation_delta:, :]
            truncated_padding_mask = make_pad_mask(encoder_out_lens - self.distillation_delta)
            # -100 marks padded targets; presumably ignored by the loss (TODO confirm).
            codebook_indexes = codebook_indexes.masked_fill(truncated_padding_mask.unsqueeze(-1), value=-100)

        N, T, _ = encoder_out.shape
        codebook_loss = self.codebook_loss_net(encoder_out.float(), codebook_indexes)
        codebook_loss = codebook_loss.reshape(N, T, -1)
        num_cb = codebook_loss.size(-1)

        if reduction == "sum":
            codebook_loss = codebook_loss.sum(dim=(1, 2)) / num_cb
        elif reduction == "none":
            codebook_loss = codebook_loss.sum(dim=2) / num_cb
        else:
            raise NotImplementedError()

        return codebook_loss

    def apply_mask(
        self,
        x: torch.Tensor,
        padding_mask: torch.Tensor = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply masking according to ``mask_mode``.

        Args:
            x (torch.Tensor): The input fbank features (N, T, C)
            padding_mask (torch.Tensor, optional): The padding mask (N, T)

        Returns:
            The masked fbank features, and the masked positions as a float
            tensor with 1 at masked frames. The mask is None in w2v2 mode when
            ``mask_prob`` is 0.
        """
        if self.mask_mode == "w2v2":
            x, masked_indices = self.apply_mask_w2v2(x, padding_mask)
        elif self.mask_mode == "block":
            x, masked_indices = self.apply_mask_block(x, padding_mask)
        else:
            raise NotImplementedError()

        # Occasional logging. random.random() is drawn first so the RNG stream
        # is consumed exactly as before; the None guard avoids crashing when
        # no time masking was applied.
        if random.random() > 0.97 and masked_indices is not None:
            logging.info(f"Apply {self.mask_mode} masking. A proportion of {masked_indices.sum()/masked_indices.numel():.2f} frames are masked")
        return x, masked_indices

    def apply_mask_block(
        self,
        x: torch.Tensor,
        padding_mask: torch.Tensor = None
    ):
        """Replace non-overlapping blocks of frames with the mask embedding.

        Returns the masked features and a float (N, T) mask with 1 at masked
        frames.
        """
        B, T, C = x.shape
        assert self.mask_prob > 0.0

        mask_indices = compute_mask_indices_block(
            shape=(B, T),
            padding_mask=padding_mask,
            mask_prob=self.mask_prob,
            mask_length=self.mask_length,
            min_masks=self.min_masks,
        ).to(x.device)

        x = index_put(x, mask_indices.bool(), self.mask_emb)

        return x, mask_indices

    def apply_mask_w2v2(
        self,
        x: torch.Tensor,
        padding_mask: torch.Tensor = None
    ):
        """wav2vec2-style masking: optional feature-channel masking (set to 0),
        then time-span masking with the learned mask embedding.

        Returns the masked features and a float (N, T) mask with 1 at masked
        frames, or None when ``mask_prob`` is 0.
        """
        B, T, C = x.shape

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=False,
                min_space=1,
                require_same_masks=False,
            )
            # Broadcast the (B, C) channel mask over all T frames.
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            if random.random() > 0.98:
                logging.info(f"A proportion of {mask_channel_indices.sum()/mask_channel_indices.numel():.2f} feature dims are masked")
            x[mask_channel_indices] = 0

        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                mask_type=self.mask_selection,
                mask_other=self.mask_other,
                # Honour the configured minimum (was hard-coded to 2).
                min_masks=self.min_masks,
                no_overlap=False,
                min_space=1,
                require_same_masks=False,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x = index_put(x, mask_indices, self.mask_emb)
            mask_indices = mask_indices.float()
        else:
            mask_indices = None

        return x, mask_indices

    @staticmethod
    def interpolate_codebook_indexes(middle_layer_output, codebook_indexes):
        """Resample (N, T, C) codebook indexes along time to the frame count of
        ``middle_layer_output``.

        Uses ``F.interpolate``'s default nearest-neighbour mode. The result is
        cast to int32.
        """
        t_expected = middle_layer_output.shape[1]
        N, T, C = codebook_indexes.shape

        codebook_indexes = codebook_indexes.permute(0, 2, 1).float()
        codebook_indexes = torch.nn.functional.interpolate(codebook_indexes, t_expected)
        codebook_indexes = codebook_indexes.permute(0, 2, 1).int()

        assert codebook_indexes.shape[1] == middle_layer_output.shape[1]
        return codebook_indexes

    @staticmethod
    def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes, ratio=2):
        """Fold ``ratio`` successive teacher frames into one student frame.

        The (N, T, C) codebook indexes are first brought to exactly
        ``t_expected * ratio`` frames, where ``t_expected`` is the frame count of
        ``middle_layer_output``. Extra frames are trimmed; a shortfall of at most
        5 frames is padded with -100. The result is then reshaped to
        (N, t_expected, C * ratio).
        """
        t_expected = middle_layer_output.shape[1]
        N, T, C = codebook_indexes.shape

        if T >= t_expected * ratio:
            codebook_indexes = codebook_indexes[:, : t_expected * ratio, :]
        else:
            assert t_expected * ratio - T <= 5, (T, t_expected, ratio)
            diff = t_expected * ratio - T
            codebook_indexes = torch.cat(
                [
                    codebook_indexes,
                    torch.full((N, diff, C), -100).to(codebook_indexes.device).to(codebook_indexes.dtype)
                ],
                dim=1,
            )
        assert codebook_indexes.size(1) == middle_layer_output.size(1) * ratio

        codebook_indexes = codebook_indexes.reshape(N, t_expected, C * ratio)
        assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
        return codebook_indexes
| |
def index_put(tensor, indices, value):
    """Write ``value`` into ``tensor[indices]`` in place and return the same tensor.

    ``value`` is broadcast against the selected elements. For example, a (C,)
    mask embedding fills every frame selected by a (N, T) boolean mask.
    """
    target = tensor
    target[indices] = value
    return target
|
|
def compute_mask_indices_block(
    shape,
    padding_mask,
    mask_prob: float = 0.5,
    mask_length: int = 10,
    min_masks: int = 2,
):
    """Compute a non-overlapping block mask for a batch of sequences.

    The valid frames of each sequence are cut into consecutive segments of
    ``mask_length`` frames. A trailing remainder shorter than one segment, and
    all padding, is never masked. Each segment is masked independently with
    probability ``mask_prob``. Sampling is repeated until at least
    ``min(min_masks, num_segments)`` segments are masked; capping at the number
    of available segments keeps short sequences from looping forever.

    Args:
        shape: ``(B, T)``, the batch size and padded number of frames.
        padding_mask: optional ``(B, T)`` bool tensor, True at padded frames.
        mask_prob: probability of masking each segment.
        mask_length: number of frames per segment.
        min_masks: minimum number of masked segments per sequence.

    Returns:
        A float tensor of shape ``(B, T)`` with 1.0 at masked frames.

    Raises:
        ValueError: if at least one segment must be masked but ``mask_prob`` is
            not positive (the resampling loop could never succeed).
    """
    B, T = shape
    rows = []
    for i in range(B):
        if padding_mask is not None:
            num_valid = T - int(padding_mask[i].sum())
        else:
            num_valid = T
        num_segments = num_valid // mask_length
        # Never demand more masked segments than exist.
        required = min(min_masks, num_segments)
        if required > 0 and mask_prob <= 0:
            raise ValueError(f"mask_prob must be > 0 to mask {required} segment(s), got {mask_prob}")

        segment_mask = torch.rand(num_segments) < mask_prob
        while int(segment_mask.sum()) < required:
            segment_mask = torch.rand(num_segments) < mask_prob

        # Expand each segment decision to mask_length frames.
        row = segment_mask.repeat_interleave(mask_length).float()
        if row.size(0) < T:
            row = torch.cat([row, torch.zeros(T - row.size(0))])
        rows.append(row)

    return torch.stack(rows)
|
|
def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
    add_masks: bool = False,
    seed: Optional[int] = None,
    epoch: Optional[int] = None,
    indices: Optional[torch.Tensor] = None,
    idc_select_ver: int = 1,
    num_mask_ver: int = 2,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape (port of fairseq's version).

    Args:
        shape: the shape for which to compute masks. Should be of size 2, where
            the first element is the batch size and the second is the number of timesteps.
        padding_mask: optional padding mask of the same size as shape (True at
            padded positions), which prevents masking padded elements.
        mask_prob: probability for each token to be chosen as start of the span
            to be masked. This is multiplied by the number of timesteps divided
            by the mask span length to mask approximately this fraction of all
            elements. Because spans can overlap, the actual fraction is usually
            smaller (unless no_overlap is True).
        mask_length: (mean) length of each span.
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample integers uniformly from [int(mask_other), mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from poisson distribution with lambda = mask length
        mask_other: secondary parameter of mask_type (see above).
        min_masks: minimum number of masked spans (with num_mask_ver == 2 it is
            still capped at sz // mask_length).
        no_overlap: if True, use an alternative recursive algorithm that
            prevents spans from overlapping.
        min_space: only used if no_overlap is True. The number of elements to
            keep unmasked between spans.
        require_same_masks: if True, randomly drop (or, with add_masks, add)
            masked positions until every row has the same number of them.
        mask_dropout: randomly unmask this fraction of masked positions in each row.
        add_masks: see require_same_masks.
        seed, epoch, indices: when all three are given, row i uses an RNG seeded
            from (seed, epoch, indices[i]), which makes its mask reproducible.
        idc_select_ver: 1 = draw span starts from [0, sz - min_len);
            2 = draw from [0, sz). Spans running past the end are clipped.
        num_mask_ver: 1 = span count computed from the padded length (or per
            row when padding_mask is given) using NumPy's global RNG;
            2 = per-row span count, capped at sz // mask_length.

    Returns:
        A bool ndarray of shape ``shape`` with True at masked positions. Rows
        that receive zero spans (e.g. shorter than mask_length) stay unmasked.
    """
    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    if num_mask_ver == 1:
        all_num_mask = int(
            # add a random number for probabilistic rounding
            mask_prob * all_sz / float(mask_length)
            + np.random.rand()
        )
        all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if seed is not None and epoch is not None and indices is not None:
            seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
        else:
            seed_i = None

        rng = np.random.default_rng(seed_i)

        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            assert sz >= 0, sz
        else:
            sz = all_sz

        if num_mask_ver == 1:
            if padding_mask is not None:
                num_mask = int(
                    # add a random number for probabilistic rounding
                    mask_prob * sz / float(mask_length)
                    + np.random.rand()
                )
                num_mask = max(min_masks, num_mask)
            else:
                num_mask = all_num_mask
        elif num_mask_ver == 2:
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + rng.random()
            )
            num_mask = max(min_masks, num_mask)
            # prevent the whole sequence from being masked
            hard_max = sz // mask_length
            num_mask = min(hard_max, num_mask)
        else:
            raise ValueError()

        if num_mask <= 0:
            # Nothing to mask for this row (e.g. it is shorter than one span).
            # Previously this raised "this should never happens" (static) or
            # failed indexing with an empty float array.
            mask_idcs.append(np.zeros(0, dtype=np.int64))
            continue

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            # np.random.Generator has no ``randint``; ``integers`` has the same
            # half-open [low, high) semantics.
            lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = rng.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = rng.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            if mask_type == "static":
                raise ValueError("this should never happens")
            else:
                lengths = [min(mask_length, sz - 1)]

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                # Place one span of `length` inside [s, e) and return the
                # remaining sub-intervals that are still large enough.
                span_start = rng.integers(s, e - length)
                mask_idc.extend(span_start + k for k in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                # ``np.int`` was removed in NumPy 1.24; builtin int is equivalent.
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    int,
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                probs = lens / l_sum
                c = rng.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc, dtype=np.int64)
        else:
            if idc_select_ver == 1:
                min_len = min(lengths)
                if sz - min_len <= num_mask:
                    min_len = sz - num_mask - 1
                mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
            elif idc_select_ver == 2:
                mask_idc = rng.choice(sz, num_mask, replace=False)
            else:
                raise ValueError()

            # Expand each span start into its covered positions. An explicit
            # integer dtype keeps the array usable as an index even when empty.
            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ],
                dtype=np.int64,
            )

        # Clip spans that run past the valid length and drop duplicates.
        mask_idc = np.unique(mask_idc[mask_idc < sz])
        if len(mask_idc) >= sz:
            raise ValueError(
                (
                    f"the entire sequence is masked. "
                    f"sz={sz}; mask_idc={mask_idc}; "
                    f"index={indices[i] if indices is not None else None}"
                )
            )
        mask_idcs.append(mask_idc)

    target_len = None
    if require_same_masks:
        if add_masks:
            target_len = max([len(m) for m in mask_idcs])
        else:
            target_len = min([len(m) for m in mask_idcs])

    # NOTE: the loop below reuses `rng` from the last row of the loop above
    # (same as fairseq).
    for i, mask_idc in enumerate(mask_idcs):
        if target_len is not None and len(mask_idc) > target_len:
            mask_idc = rng.choice(mask_idc, target_len, replace=False)

        mask[i, mask_idc] = True

        if target_len is not None and len(mask_idc) < target_len:
            unmasked = np.flatnonzero(~mask[i])
            to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
            mask[i, to_mask] = True

        if mask_dropout > 0:
            masked = np.flatnonzero(mask[i])
            num_holes = np.rint(len(masked) * mask_dropout).astype(int)
            to_drop = rng.choice(masked, num_holes, replace=False)
            mask[i, to_drop] = False

    return mask
|
|
| def _to_int_tuple(s: str): |
| return tuple(map(int, s.split(","))) |
|
|
def get_encoder_embed(config: SpearConfig) -> nn.Module:
    """Build the Conv2dSubsampling front-end.

    Input channels come from ``config.num_mel_bins``. Output channels are the
    first entry of the comma-separated ``config.encoder_dim``.
    """
    mel_bins = config.num_mel_bins
    first_dim = _to_int_tuple(config.encoder_dim)[0]
    return Conv2dSubsampling(in_channels=mel_bins, out_channels=first_dim)
|
|
def get_encoder_model(config: SpearConfig) -> nn.Module:
    """Instantiate the Zipformer2 encoder described by ``config``.

    Comma-separated string fields of the config are parsed into int tuples.
    The head dimensions (query 32, positional 4, value 12) and
    ``warmup_batches`` are fixed here rather than read from the config.
    """
    as_ints = _to_int_tuple
    return Zipformer2(
        output_downsampling_factor=config.output_downsampling_factor,
        downsampling_factor=as_ints(config.downsampling_factor),
        num_encoder_layers=as_ints(config.num_encoder_layers),
        encoder_dim=as_ints(config.encoder_dim),
        encoder_unmasked_dim=as_ints(config.encoder_unmasked_dim),
        query_head_dim=(32,),
        pos_head_dim=(4,),
        value_head_dim=(12,),
        pos_dim=config.pos_dim,
        num_heads=as_ints(config.num_heads),
        feedforward_dim=as_ints(config.feedforward_dim),
        cnn_module_kernel=as_ints(config.cnn_module_kernel),
        warmup_batches=4000.0,
        causal=config.causal,
        chunk_size=config.chunk_size,
        left_context_frames=config.left_context_frames,
    )
|
|
|
|
def _test_w2v2_channel_mask():
    """Print the average fraction of feature dims masked by w2v2 channel masking
    for a few (mask_channel_prob, mask_channel_length) settings."""
    x = torch.ones(100, 1000, 128)
    B, T, C = x.shape

    for mask_channel_prob, mask_channel_length in [(0.25, 15), (0.25, 20), (0.5, 15)]:
        ratios = []
        for _ in range(20):
            channel_mask = compute_mask_indices(
                (B, C),
                None,
                mask_channel_prob,
                mask_channel_length,
                "static",
                0.0,
                no_overlap=False,
                min_space=1,
                require_same_masks=False,
            )
            # Broadcast the (B, C) channel mask over all T frames.
            channel_mask = torch.from_numpy(channel_mask).to(x.device).unsqueeze(1).expand(-1, T, -1)
            ratios.append(channel_mask.sum() / channel_mask.numel())
        avg_ratio = sum(ratios) / len(ratios)
        print(f"Current config: mask_channel_prob = {mask_channel_prob}, mask_channel_length = {mask_channel_length}")
        print(f"Averaged masking ratio: {avg_ratio}")
|
|
def _test_w2v2_mask():
    """Print the average fraction of frames masked by ``compute_mask_indices``
    (w2v2 time masking) for a few (mask_prob, mask_length) settings."""
    # Only the (B, T) shape matters. The original allocated an unused
    # 100x1000x128 tensor and built a config sweep that was immediately
    # overwritten; both were dead code.
    B, T = 100, 1000

    configs = [(0.65, 10), (0.02, 40), (0.05, 40), (0.1, 40)]
    for mask_prob, mask_length in configs:
        ratios = []
        for _ in range(20):
            mask_indices = compute_mask_indices(
                (B, T),
                None,
                mask_prob,
                mask_length,
                mask_type="static",
                mask_other=0.0,
                min_masks=2,
                no_overlap=False,
                min_space=1,
                require_same_masks=False,
            )
            mask_indices = torch.from_numpy(mask_indices)
            ratios.append(mask_indices.sum() / mask_indices.numel())
        avg_ratio = sum(ratios) / len(ratios)
        print(f"Current config: mask_prob = {mask_prob}, mask_length = {mask_length}")
        print(f"Averaged masking ratio: {avg_ratio}")
|
|
def _test_custom_mask():
    """Print the average fraction of frames masked by ``compute_mask_indices_block``
    for several mask_prob values, with the block length jittered per draw."""
    B, T = 100, 1000

    configs = [(0.5, 20), (0.2, 20), (0.3, 20), (0.4, 20), (0.5, 20)]
    for mask_prob, base_mask_length in configs:
        # Jitter the block length around the configured base by up to +-10
        # frames in steps of 2. The original overwrote the base with each
        # sample, so the length random-walked and could drop to <= 0.
        candidate_lengths = [base_mask_length + 2 * k for k in range(-5, 6)]
        ratios = []
        for _ in range(20):
            mask_length = random.choice(candidate_lengths)
            assert mask_length > 0, f"Sampled mask_length smaller than 0, {mask_length}"

            mask_indices = compute_mask_indices_block(
                shape=(B, T),
                padding_mask=None,
                mask_prob=mask_prob,
                mask_length=mask_length,
                min_masks=2,
            )
            ratios.append(mask_indices.sum() / mask_indices.numel())
        avg_ratio = sum(ratios) / len(ratios)
        print(f"Current config: mask_prob = {mask_prob}, mask_length = {base_mask_length}")
        print(f"Averaged masking ratio: {avg_ratio}")
| |
|
|
if __name__ == "__main__":
    # Run the masking-ratio sanity checks in order.
    for _demo in (_test_w2v2_channel_mask, _test_w2v2_mask, _test_custom_mask):
        _demo()