| from torchaudio.models import Conformer |
| from torchaudio.models.rnnt import _TimeReduction |
| from transformers import PretrainedConfig, PreTrainedModel |
| import torch |
| from torch import nn |
| from typing import List, Tuple, Optional |
|
|
|
|
class ConformerConfig(PretrainedConfig):
    """Configuration for :class:`ConformerEncoder`.

    Declares every field the encoder reads, with defaults, so that a bare
    ``ConformerConfig()`` is usable and ``save_pretrained`` /
    ``from_pretrained`` round-trips serialize all hyperparameters (the
    standard ``PretrainedConfig`` subclassing contract).

    NOTE(review): the default values below follow a common LibriSpeech
    Conformer-CTC setup — confirm against the training recipe actually
    used with this model.
    """

    model_type = 'conformer'

    def __init__(
        self,
        input_dim=80,
        time_reduction_stride=4,
        conformer_input_dim=256,
        conformer_ffn_dim=1024,
        conformer_num_layers=16,
        conformer_num_heads=4,
        conformer_depthwise_conv_kernel_size=31,
        conformer_dropout=0.1,
        output_dim=1024,
        ctc_loss_reduction='mean',
        ctc_zero_infinity=True,
        pad_token_id=0,
        **kwargs,
    ):
        # Acoustic-feature and time-reduction settings.
        self.input_dim = input_dim
        self.time_reduction_stride = time_reduction_stride
        # Conformer stack hyperparameters.
        self.conformer_input_dim = conformer_input_dim
        self.conformer_ffn_dim = conformer_ffn_dim
        self.conformer_num_layers = conformer_num_layers
        self.conformer_num_heads = conformer_num_heads
        self.conformer_depthwise_conv_kernel_size = conformer_depthwise_conv_kernel_size
        self.conformer_dropout = conformer_dropout
        # Output vocabulary size and CTC loss behavior.
        self.output_dim = output_dim
        self.ctc_loss_reduction = ctc_loss_reduction
        self.ctc_zero_infinity = ctc_zero_infinity
        # pad_token_id doubles as the CTC blank index in the encoder;
        # PretrainedConfig stores it via the super call.
        super().__init__(pad_token_id=pad_token_id, **kwargs)
|
|
|
|
class ConformerEncoder(PreTrainedModel):
    """Conformer acoustic encoder with an optional CTC training loss.

    Pipeline: frame-stacking time reduction -> linear projection ->
    torchaudio Conformer stack -> linear projection to the output
    vocabulary. When ``labels`` are supplied, a CTC loss over the
    log-softmax of the logits is prepended to the outputs.
    """

    config_class = ConformerConfig

    def __init__(self, config) -> None:
        super().__init__(config)
        # Stacks `time_reduction_stride` consecutive frames into one,
        # which multiplies the feature dimension by the same factor.
        self.time_reduction = _TimeReduction(config.time_reduction_stride)
        self.input_linear = nn.Linear(
            config.input_dim * config.time_reduction_stride,
            config.conformer_input_dim,
        )
        self.conformer = Conformer(
            num_layers=config.conformer_num_layers,
            input_dim=config.conformer_input_dim,
            ffn_dim=config.conformer_ffn_dim,
            num_heads=config.conformer_num_heads,
            depthwise_conv_kernel_size=config.conformer_depthwise_conv_kernel_size,
            dropout=config.conformer_dropout,
            use_group_norm=True,
            convolution_first=True,
        )
        self.output_linear = nn.Linear(config.conformer_input_dim, config.output_dim)

    def forward(self, inputs, lengths, labels=None):
        """Encode a batch of feature sequences.

        Args:
            inputs: batched acoustic features
                (assumed (batch, time, input_dim) — TODO confirm against
                the torchaudio ``_TimeReduction`` contract).
            lengths: per-sequence valid frame counts before reduction.
            labels: optional padded target-id tensor; positions with a
                negative id are treated as padding and excluded from the
                CTC targets.

        Returns:
            ``(logits, output_lengths)`` when ``labels`` is ``None``,
            otherwise ``(loss, logits, output_lengths)``.
        """
        reduced, reduced_lengths = self.time_reduction(inputs, lengths)
        features, feature_lengths = self.conformer(
            self.input_linear(reduced), reduced_lengths
        )
        logits = self.output_linear(features)

        loss = None
        if labels is not None:
            # Keep only non-negative label positions; negative ids mark
            # padding and must not become CTC targets.
            label_mask = labels >= 0
            target_lengths = label_mask.sum(-1)
            targets = labels.masked_select(label_mask)
            # ctc_loss expects (time, batch, vocab) log-probs; compute the
            # softmax in float32 for numerical stability.
            log_probs = nn.functional.log_softmax(
                logits, dim=-1, dtype=torch.float32
            ).transpose(0, 1)

            # cuDNN's CTC kernel is disabled here — presumably to avoid
            # its input-layout/dtype restrictions; the native
            # implementation is used instead.
            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    targets,
                    feature_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        outputs = (logits, feature_lengths)
        return outputs if loss is None else (loss,) + outputs
|
|