| import torch |
| from torch import nn |
| from safe_concat import * |
| from nemo.collections.asr.models import SortformerEncLabelModel |
|
|
|
|
def fixed_concat_and_pad(embs, lengths, max_total_len=188+188+6):
    """
    ANE-safe concat and pad that avoids zero-length slices.

    Uses gather with arithmetic-computed indices to pack valid frames
    efficiently: each output position maps to an input index that skips
    the padding tail of every segment that has already been consumed.

    Args:
        embs: List of 3 tensors [spkcache, fifo, chunk], each (B, seq_len, D)
        lengths: List of 3 length tensors, each (1,) or scalar
            First two may be 0, third is always > 0
        max_total_len: Output sequence length (padded with zeros)

    Returns:
        output: (B, max_total_len, D) with valid frames packed at the start,
            same dtype as the inputs
        total_length: 0-d tensor, sum of the three lengths
    """
    B, _, D = embs[0].shape
    device = embs[0].device

    size0, size1, size2 = embs[0].shape[1], embs[1].shape[1], embs[2].shape[1]
    total_input_size = size0 + size1 + size2

    # Single static concat; all dynamic behavior is expressed through the
    # gather indices below, so no zero-length slice is ever created.
    full_concat = torch.cat(embs, dim=1)

    # Normalize lengths to 0-d tensors so the broadcast arithmetic is
    # shape-stable regardless of (1,) vs scalar inputs.
    len0 = lengths[0].reshape(())
    len1 = lengths[1].reshape(())
    len2 = lengths[2].reshape(())
    total_length = len0 + len1 + len2

    out_pos = torch.arange(max_total_len, device=device, dtype=torch.long)

    # Output-position boundaries where packed segment 1 / segment 2 begin.
    cumsum0 = len0
    cumsum1 = len0 + len1

    in_seg1_or_2 = (out_pos >= cumsum0).long()
    in_seg2 = (out_pos >= cumsum1).long()

    # Once past segment 0 (resp. 1), skip its (size - len) padding frames.
    offset = in_seg1_or_2 * (size0 - len0) + in_seg2 * (size1 - len1)
    # Clamp keeps indices legal for positions past total_length; those
    # positions are zero-masked below, so the gathered value is irrelevant.
    gather_idx = (out_pos + offset).clamp(0, total_input_size - 1)

    gather_idx = gather_idx.unsqueeze(0).unsqueeze(-1).expand(B, max_total_len, D)

    output = torch.gather(full_concat, dim=1, index=gather_idx)
    # Zero out positions beyond total_length. Cast the mask to the input
    # dtype (not .float()) so fp16/bf16 inputs are not promoted to fp32 —
    # important for a reduced-precision CoreML/ANE export path.
    valid = (out_pos < total_length).unsqueeze(0).unsqueeze(-1)
    output = output * valid.to(output.dtype)

    return output, total_length
|
|
|
|
class PreprocessorWrapper(nn.Module):
    """
    Thin export shim around the NeMo preprocessor (FilterbankFeaturesTA).

    CoreML export wants a module with the plain positional signature
    (audio, length) -> (features, length); this adapter maps those
    positional inputs onto the preprocessor's keyword interface.
    """

    def __init__(self, preprocessor):
        super().__init__()
        self.preprocessor = preprocessor

    def forward(self, audio_signal, length):
        # Pure delegation: the wrapped preprocessor already returns the
        # (features, length) pair the exporter expects.
        result = self.preprocessor(input_signal=audio_signal, length=length)
        return result
|
|
|
|
class SortformerHeadWrapper(nn.Module):
    """
    Export shim that runs the front-end Conformer encoder (pre-encode
    bypassed) followed by the Sortformer prediction head, echoing the
    chunk pre-encode tensors back unchanged for the caller's
    streaming-state bookkeeping.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, pre_encoder_embs, pre_encoder_lengths, chunk_pre_encoder_embs, chunk_pre_encoder_lengths):
        # Front-end Conformer pass; inputs are already pre-encoded.
        encoder_embs, encoder_lengths = self.model.frontend_encoder(
            processed_signal=pre_encoder_embs,
            processed_signal_length=pre_encoder_lengths,
            bypass_pre_encode=True,
        )

        # Sortformer head over the encoded sequence.
        preds = self.model.forward_infer(encoder_embs, encoder_lengths)

        # Chunk pre-encode outputs pass through untouched.
        return preds, chunk_pre_encoder_embs, chunk_pre_encoder_lengths
|
|
|
|
class SortformerCoreMLWrapper(nn.Module):
    """
    Wraps the entire Sortformer pipeline (Encoder + Streaming Logic for Export):
    pre-encode/pack -> front-end Conformer encoder -> Sortformer head.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.pre_encoder = PreEncoderWrapper(model)

    def forward(self, chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths):
        # Pre-encode the raw chunk and pack it behind the cached state
        # (speaker cache + FIFO).
        (packed_embs, packed_lengths,
         chunk_pre_encode_embs, chunk_pre_encode_lengths) = self.pre_encoder(
            chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths
        )

        # Front-end Conformer over the packed sequence; pre-encode is
        # bypassed because it already happened above.
        encoder_embs, encoder_lengths = self.model.frontend_encoder(
            processed_signal=packed_embs,
            processed_signal_length=packed_lengths,
            bypass_pre_encode=True,
        )

        # Sortformer head predictions; chunk pre-encode outputs are
        # returned as-is so the caller can update its streaming state.
        preds = self.model.forward_infer(encoder_embs, encoder_lengths)
        return preds, chunk_pre_encode_embs, chunk_pre_encode_lengths
|
|
|
|
class PreEncoderWrapper(nn.Module):
    """
    Export shim for the pre-encode stage: runs the Conformer pre-encoder
    on the chunk and, in the 6-argument form, packs the result behind the
    speaker-cache and FIFO embeddings via `fixed_concat_and_pad`.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        modules = model.sortformer_modules
        # Fixed packed-sequence length: speaker cache + FIFO + the full
        # (left context + chunk + right context) chunk window.
        chunk_length = modules.chunk_left_context + modules.chunk_len + modules.chunk_right_context
        self.pre_encoder_length = modules.spkcache_len + modules.fifo_len + chunk_length

    def forward(self, *args):
        # Dispatch on arity: 6 inputs -> pre-encode + concat with cached
        # state; 2 inputs -> pre-encode only.
        if len(args) == 6:
            return self.forward_concat(*args)
        return self.forward_pre_encode(*args)

    def forward_concat(self, chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths):
        chunk_embs, chunk_lens = self.forward_pre_encode(chunk, chunk_lengths)
        packed_embs, packed_lens = fixed_concat_and_pad(
            [spkcache, fifo, chunk_embs],
            [spkcache_lengths, fifo_lengths, chunk_lens],
            self.pre_encoder_length,
        )
        return packed_embs, packed_lens, chunk_embs, chunk_lens

    def forward_pre_encode(self, chunk, chunk_lengths):
        embs, lens = self.model.encoder.pre_encode(x=chunk, lengths=chunk_lengths)
        # Downstream index arithmetic expects int64 lengths.
        return embs, lens.to(torch.int64)
|
|
|
|
class ConformerEncoderWrapper(nn.Module):
    """
    Export shim exposing only the front-end Conformer encoder, fed with
    embeddings that have already been pre-encoded (pre-encode bypassed).
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, pre_encode_embs, pre_encode_lengths):
        # Single delegated call; keyword args mirror the NeMo API.
        encoder_embs, encoder_lengths = self.model.frontend_encoder(
            processed_signal=pre_encode_embs,
            processed_signal_length=pre_encode_lengths,
            bypass_pre_encode=True,
        )
        return encoder_embs, encoder_lengths
|
|
|
|
class SortformerEncoderWrapper(nn.Module):
    """
    Export shim exposing only the Sortformer prediction head
    (`forward_infer`) over already-encoded embeddings.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, encoder_embs, encoder_lengths):
        # Direct pass-through to the model's inference head.
        return self.model.forward_infer(encoder_embs, encoder_lengths)
|
|