| from typing import Optional |
|
|
| import torch |
| from torch import nn |
| from transformers.modeling_outputs import CausalLMOutput, BaseModelOutput |
| from transformers.models.whisper.modeling_whisper import WhisperEncoder, WhisperEncoderLayer, WHISPER_ATTENTION_CLASSES |
|
|
| from .config import DiCoWConfig |
|
|
|
|
class CustomLinear(nn.Linear):
    """`nn.Linear` that additionally remembers an `init_eye_val` setting.

    The layer computes exactly what `nn.Linear` computes. The extra value is
    only stored on the instance; nothing in this class applies it to the
    weights (presumably a weight-initialisation routine elsewhere reads it --
    confirm against the model's init code).

    Args:
        *args: Positional arguments forwarded unchanged to `nn.Linear`.
        init_eye_val: Value stored as `self.init_eye_val`.
        is_diagonal: Accepted for signature compatibility and ignored; naming
            it here keeps it out of `**kwargs`, so it is never passed on to
            `nn.Linear`.
        **kwargs: Keyword arguments forwarded unchanged to `nn.Linear`.
    """

    def __init__(self, *args, init_eye_val=0.0, is_diagonal=False, **kwargs):
        super(CustomLinear, self).__init__(*args, **kwargs)
        # Stored only; not used to initialise `self.weight` in this class.
        self.init_eye_val = init_eye_val
|
|
|
|
class CustomDiagonalLinear(nn.Module):
    """Element-wise affine layer: one learnable scale (and optional shift) per feature.

    Args:
        d_model: Length of the `weight` (and `bias`) vectors.
        bias: When true, a zero-initialised learnable bias is created;
            otherwise `self.bias` is `None`.
        init_eye_val: Initial value of every entry of `weight`; also stored
            as `self.init_eye_val`.
    """

    def __init__(self, d_model, bias=True, init_eye_val=0.0):
        super().__init__()
        self.init_eye_val = init_eye_val
        # Every scale entry starts at `init_eye_val`.
        self.weight = nn.Parameter(torch.full((d_model,), init_eye_val))
        if bias:
            self.bias = nn.Parameter(torch.zeros(d_model))
        else:
            self.bias = None

    def forward(self, input):
        """Return `input * weight`, plus `bias` when one exists."""
        scaled = input * self.weight
        if self.bias is None:
            return scaled
        # In-place is safe: `scaled` was created just above, so the caller's
        # `input` tensor is never modified.
        return scaled.add_(self.bias)
|
|
|
|
class FDDT(nn.Module):
    """Mask-conditioned transformation of hidden states.

    One transform is kept per mask channel, indexed along dim 1 of
    `stno_mask`: 0 = silence, 1 = target, 2 = non-target, 3 = overlap and,
    optionally, 4 = interaction (`self.scb`). `forward` mixes the per-channel
    results, weighting each by its mask channel.

    Args:
        d_model: Feature size of the hidden states; used as the length of the
            bias vectors and as in/out features of the linear transforms.
        non_target_rate: `init_eye_val` handed to the non-target and silence
            transforms (target, overlap and interaction use 1.0).
        is_diagonal: Build `CustomDiagonalLinear` transforms instead of full
            `CustomLinear` ones (ignored when `bias_only` is true).
        bias_only: Each transform is just a zero-initialised bias vector.
        use_silence, use_target, use_overlap, use_non_target: Whether the
            corresponding transform is created and applied. A disabled
            channel passes the hidden states through unchanged in linear
            mode and contributes nothing in bias-only mode.
        use_interaction: Enable the fifth (interaction) channel.
        scb_module: Optional pre-built object used as `self.scb` instead of a
            freshly created transform.
    """

    def __init__(self, d_model, non_target_rate=0.01, is_diagonal=False, bias_only=False, use_silence=True,
                 use_target=True, use_overlap=True, use_non_target=True, use_interaction=False,
                 scb_module: Optional[nn.Module] = None, ):
        super().__init__()

        def make_transform(init_eye_val):
            # Single place that decides which kind of transform a channel gets.
            if bias_only:
                return nn.Parameter(torch.zeros(d_model))
            if is_diagonal:
                return CustomDiagonalLinear(d_model, bias=True, init_eye_val=init_eye_val)
            return CustomLinear(d_model, d_model, bias=True, init_eye_val=init_eye_val)

        # Creation order is kept stable (target, non-target, overlap, silence,
        # interaction) so seeded initialisation and parameter ordering do not
        # change.
        if use_target:
            self.target_linear = make_transform(1.0)
        if use_non_target:
            self.non_target_linear = make_transform(non_target_rate)
        if use_overlap:
            self.overlap_linear = make_transform(1.0)
        if use_silence:
            self.silence_linear = make_transform(non_target_rate)
        if use_interaction:
            self.scb = scb_module if scb_module is not None else make_transform(1.0)

        self.use_silence = use_silence
        self.use_target = use_target
        self.use_overlap = use_overlap
        self.use_non_target = use_non_target
        self.use_interaction = use_interaction
        self.bias_only = bias_only

    @staticmethod
    def mask_out_non_interaction_signal(hidden_states, mask):
        """Zero `hidden_states` wherever `mask` rounds to 0; keep it where it rounds to 1."""
        keep = torch.round(mask).bool()
        return hidden_states * keep

    def forward(self, hidden_states, stno_mask):
        """Apply the per-channel transforms and mix them using `stno_mask`.

        Args:
            hidden_states: Tensor to transform. It is never modified in place.
            stno_mask: Per-channel weights, indexed as `stno_mask[:, k]` for
                channel `k`. Assumed to be (batch, channels, frames) with 4
                or 5 channels -- TODO confirm against the caller.

        Returns:
            The transformed hidden states.
        """
        # The trailing singleton axis lets each channel broadcast over features.
        stno_mask = stno_mask.to(hidden_states.device)[..., None]

        if self.bias_only:
            # Accumulate into a copy: the caller may still hold a reference to
            # the input (e.g. when collecting intermediate hidden states), so
            # it must not be altered. In-place adds on the copy keep the
            # input's dtype.
            hidden_states = hidden_states.clone()
            if self.use_silence:
                hidden_states += stno_mask[:, 0, ...] * self.silence_linear
            if self.use_target:
                hidden_states += stno_mask[:, 1, ...] * self.target_linear
            if self.use_non_target:
                hidden_states += stno_mask[:, 2, ...] * self.non_target_linear
            if self.use_overlap:
                hidden_states += stno_mask[:, 3, ...] * self.overlap_linear
            if self.use_interaction:
                hidden_states += stno_mask[:, 4, ...] * self.scb
            return hidden_states

        # Linear mode: a disabled channel passes the input through unchanged.
        silence = self.silence_linear(hidden_states) if self.use_silence else hidden_states
        target = self.target_linear(hidden_states) if self.use_target else hidden_states
        non_target = self.non_target_linear(hidden_states) if self.use_non_target else hidden_states
        overlap = self.overlap_linear(hidden_states) if self.use_overlap else hidden_states

        mixed = (
            silence * stno_mask[:, 0, :]
            + target * stno_mask[:, 1, :]
            + non_target * stno_mask[:, 2, :]
            + overlap * stno_mask[:, 3, :]
        )

        if self.use_interaction:
            interaction_mask = stno_mask[:, 4, :]
            interaction_input = self.mask_out_non_interaction_signal(hidden_states, interaction_mask)
            mixed = mixed + self.scb(interaction_input) * interaction_mask
        elif stno_mask.size(1) != 4:
            # A fifth channel is present but has no transform: pass the input
            # through for it.
            mixed = mixed + hidden_states * stno_mask[:, 4, :]
        return mixed
|
|
|
|
class DiCoWEncoder(WhisperEncoder):
    """Whisper encoder extended with mask-conditioned FDDT layers and an optional CTC head.

    On top of the stock `WhisperEncoder` this class can:

    * apply an `FDDT` module to the convolutional embeddings and to the input
      of the first N encoder layers, driven by `stno_mask`;
    * project the encoder output through an optional extra layer, dropout,
      optional stride-2 sub-sampling and an `lm_head` to produce CTC logits
      (only when `config.ctc_weight > 0`).
    """

    config_class = DiCoWConfig

    def __init__(self, config: DiCoWConfig):
        super().__init__(config)
        self.ctc_weight = config.ctc_weight

        # All CTC-branch modules exist only when the CTC loss is in use.
        if config.additional_layer and self.ctc_weight > 0.0:
            self.additional_layer = WhisperEncoderLayer(config)
        if config.additional_self_attention_layer and self.ctc_weight > 0.0:
            self.additional_self_attention_layer = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
                embed_dim=config.d_model,
                num_heads=config.encoder_attention_heads,
                dropout=config.attention_dropout,
                config=config,
            )
        if config.sub_sample and self.ctc_weight > 0.0:
            # Two stride-2 convolutions shorten the time axis before the CTC head.
            self.subsample_conv1 = nn.Conv1d(
                in_channels=config.d_model,
                out_channels=config.d_model,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False,
            )
            self.subsample_conv2 = nn.Conv1d(
                in_channels=config.d_model,
                out_channels=config.d_model,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False,
            )
        if self.ctc_weight > 0.0:
            # One extra output class: the last index is used as the CTC blank.
            self.lm_head = nn.Linear(config.d_model, config.vocab_size + 1, bias=False)
        # `forward` applies this dropout unconditionally, so it always exists.
        self.final_dropout = nn.Dropout(config.final_dropout)

        if config.use_fddt:
            # -1 means "one FDDT per encoder layer".
            num_fddts = self.config.apply_fddt_to_n_layers if self.config.apply_fddt_to_n_layers != -1 else len(
                self.layers)
            self.initial_fddt = FDDT(config.d_model,
                                     non_target_rate=config.non_target_fddt_value,
                                     is_diagonal=config.fddt_is_diagonal,
                                     bias_only=config.fddt_bias_only,
                                     use_silence=config.fddt_use_silence,
                                     use_target=config.fddt_use_target,
                                     use_overlap=config.fddt_use_overlap,
                                     use_non_target=config.fddt_use_non_target)
            is_mt = config.mt_num_speakers > 1
            num_scbs = (self.config.scb_layers if self.config.scb_layers != -1 else len(
                self.layers)) if is_mt else 0
            # Stored for use elsewhere; not read anywhere in this class.
            self.scbs_identity_layers = config.encoder_layers - num_scbs
            self.fddts = nn.ModuleList([
                FDDT(config.d_model,
                     non_target_rate=1.0,
                     is_diagonal=config.fddt_is_diagonal,
                     bias_only=config.fddt_bias_only,
                     use_silence=config.fddt_use_silence,
                     use_target=config.fddt_use_target,
                     use_overlap=config.fddt_use_overlap,
                     use_non_target=config.fddt_use_non_target,
                     use_interaction=is_mt,
                     )
                for _ in range(num_fddts)
            ])

        # Labels at or above this id are stripped from CTC targets when
        # `config.remove_timestamps_from_ctc` is set (presumably the start of
        # Whisper's task/timestamp token range -- confirm against the tokenizer).
        self.first_task_token = self.config.vocab_size - 30 * 50 - 1 - 6
        self.post_init()

    @classmethod
    def _load_pretrained_model(
            cls,
            model,
            state_dict,
            loaded_keys,
            resolved_archive_file,
            pretrained_model_name_or_path,
            **kwargs
    ):
        """Strip a leading ``encoder.`` from checkpoint keys, then defer to the parent loader.

        `state_dict` and `loaded_keys` are both updated in place so that they
        stay consistent with each other.
        """
        prefix = "encoder."
        for key in list(state_dict.keys()):
            if key.startswith(prefix):
                stripped = key[len(prefix):]
                state_dict[stripped] = state_dict.pop(key)
                loaded_keys.remove(key)
                loaded_keys.append(stripped)
        return super()._load_pretrained_model(
            model,
            state_dict,
            loaded_keys,
            resolved_archive_file,
            pretrained_model_name_or_path,
            **kwargs
        )

    def get_loss(self, logits, labels):
        """Compute the CTC loss of `logits` against `labels`.

        Args:
            logits: CTC logits, indexed as (batch, time, classes); the last
                class index is treated as the blank symbol.
            labels: Target ids, one row per sample; negative entries are
                treated as padding and excluded from the target length.

        Returns:
            The CTC loss reduced according to `config.ctc_loss_reduction`.

        Raises:
            ValueError: If any label id is not smaller than `config.vocab_size`.
        """
        if labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")
        if self.config.remove_timestamps_from_ctc:
            # Keep only ids below `first_task_token` in each row, then re-pad
            # the (now ragged) rows with -100.
            labels = torch.nn.utils.rnn.pad_sequence(
                [label[label < self.first_task_token] for label in labels],
                batch_first=True,
                padding_value=-100,
            )

        # Every sample is assumed to span the full time axis of `logits`.
        input_lengths = torch.full((logits.shape[0],), fill_value=logits.shape[1],
                                   device=logits.device)

        labels_mask = labels >= 0
        target_lengths = labels_mask.sum(-1)

        # `ctc_loss` expects (time, batch, classes); float32 for numerical stability.
        log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

        with torch.backends.cudnn.flags(enabled=True):
            ctc_loss = nn.functional.ctc_loss(
                log_probs,
                labels,
                input_lengths,
                target_lengths,
                blank=logits.shape[-1] - 1,
                reduction=self.config.ctc_loss_reduction,
                zero_infinity=True,
            )
        return ctc_loss

    def forward(
            self,
            input_features,
            attention_mask=None,
            head_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            stno_mask=None,
            per_group_sizes=None
    ):
        """Encode `input_features` and, when CTC is enabled, produce CTC logits.

        Args:
            input_features: Mel features whose last dimension must equal
                `max_source_positions * conv1.stride * conv2.stride`.
            attention_mask: Unused.
            head_mask: Optional per-layer head mask; its first dimension must
                equal the number of encoder layers.
            output_attentions: Whether to collect per-layer attention weights.
            output_hidden_states: Whether to collect per-layer hidden states.
            return_dict: Return a `CausalLMOutput` (true) or a plain tuple of
                its non-`None` fields (false). Defaults to the config value.
            stno_mask: Mask consumed by the FDDT modules and by shifted
                position embeddings; required when either is enabled.
            per_group_sizes: Unused.

        Returns:
            `CausalLMOutput` with `logits` (`None` when `ctc_weight` is 0),
            `hidden_states` and `attentions`; or its tuple form when
            `return_dict` is false. For over-long inputs an all-`None`
            `CausalLMOutput` is returned.

        Raises:
            ValueError: If the input is shorter than expected, or if
                `stno_mask` is missing although it is needed.
        """
        expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
        if input_features.shape[-1] != expected_seq_length:
            if input_features.shape[-1] > expected_seq_length:
                # Over-long inputs are skipped rather than rejected.
                return CausalLMOutput(
                    logits=None,
                    hidden_states=None,
                    attentions=None,
                )
            else:
                raise ValueError(
                    f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
                )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # `shift_embeds` is optional; it is not set anywhere in this class.
        shift_embeds = getattr(self, "shift_embeds", False)
        if stno_mask is None and (self.config.use_fddt or shift_embeds):
            raise ValueError("`stno_mask` must be provided when FDDT or shifted position embeddings are enabled.")

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        embed_pos = self.embed_positions.weight
        if shift_embeds:
            # Position index = running count of (target + overlap) mass up to
            # each frame, minus one and clamped at zero.
            embed_pos = embed_pos[
                torch.clamp(((stno_mask[:, 1, :] + stno_mask[:, 3, :]).cumsum(dim=-1) - 1), min=0).to(torch.long)]

        if self.config.use_fddt:
            inputs_embeds = self.initial_fddt(inputs_embeds, stno_mask)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # Check that head_mask has one entry per layer.
        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

            # LayerDrop: randomly skip whole layers during training.
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    to_drop = True

            # FDDT is applied to the layer input, also when the layer is skipped.
            if self.config.use_fddt and idx < len(self.fddts):
                hidden_states = self.fddts[idx](hidden_states, stno_mask)

            if to_drop:
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        None,
                        (head_mask[idx] if head_mask is not None else None),
                        output_attentions,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        None,
                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                        output_attentions=output_attentions,
                    )
                # Only a layer that actually ran updates the hidden states; a
                # dropped layer must leave them untouched.
                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        # CTC branch. Only the first element of each layer's output tuple is
        # needed; indexing (rather than unpacking) works regardless of how
        # many extra elements the layer returns.
        if hasattr(self, "additional_layer"):
            inter_output = self.additional_layer(
                hidden_states,
                attention_mask=None,
                output_attentions=output_attentions,
                layer_head_mask=None,
            )[0]
        elif hasattr(self, "additional_self_attention_layer"):
            inter_output = self.additional_self_attention_layer(
                hidden_states,
                attention_mask=None,
                output_attentions=output_attentions,
                layer_head_mask=None,
            )[0]
        else:
            inter_output = hidden_states

        inter_output = self.final_dropout(inter_output)
        if hasattr(self, "subsample_conv2"):
            # Conv1d works on (batch, channels, time); transpose in and back out.
            inter_output = self.subsample_conv2(self.subsample_conv1(inter_output.transpose(1, 2))).transpose(1, 2)
        if self.ctc_weight > 0.0:
            logits = self.lm_head(inter_output)
        else:
            logits = None

        output = CausalLMOutput(
            logits=logits,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )
        return output if return_dict else output.to_tuple()
|
|