from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers import WhisperPreTrainedModel
from transformers import WhisperModel, WhisperForConditionalGeneration
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from transformers.modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput, BaseModelOutput
from transformers.models.whisper.configuration_whisper import WhisperConfig
from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer
from transformers.models.whisper.modeling_whisper import shift_tokens_right
|
|
class WhisperMultitaskConfig(WhisperConfig):
    """Whisper configuration extended with settings for auxiliary encoder heads.

    Every regular ``WhisperConfig`` option is accepted through ``**kwargs`` and
    forwarded to the parent constructor. The extra options below are read by
    ``WhisperMultitask`` to build its character-CTC, phoneme-CTC, VAD and
    speaker-diarization heads.

    Per-head options:
        *_hidden_layer: index into the encoder hidden-state tuple that the head
            reads (used when ``use_weighted_layer_sum`` is off; ``-1`` = last).
        *_layers: number of extra ``WhisperEncoderLayer`` blocks stacked on top
            of the selected hidden state before the head's LayerNorm.
        ctc_*_dropout: dropout probability applied right before the CTC projection.
        ctc_*_vocab_size: output width of the corresponding CTC projection.
        diarization_max_speakers: output width of the diarization classifier.

    ``ctc_loss_reduction`` and ``ctc_zero_infinity`` are stored on the model,
    but the model's ``forward`` does not consume them. NOTE(review): presumably
    intended for a CTC loss computed by the caller — confirm.
    """

    model_type = "whisper"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Attribute aliases resolved by PretrainedConfig (e.g. config.hidden_size -> config.d_model).
    attribute_map = {
        "num_key_value_heads": "encoder_attention_heads",
        "num_attention_heads": "encoder_attention_heads",
        "hidden_size": "d_model",
    }

    def __init__(
        self,
        ctc_char_dropout=0.1,
        ctc_char_vocab_size=33,
        ctc_char_layers=0,
        ctc_char_hidden_layer=-1,
        ctc_phoneme_dropout=0.1,
        ctc_phoneme_vocab_size=33,
        ctc_phoneme_layers=0,
        ctc_phoneme_hidden_layer=-1,
        vad_hidden_layer=-1,
        vad_layers=0,
        diarization_hidden_layer=-1,
        diarization_max_speakers=5,
        diarization_layers=0,
        ctc_loss_reduction="mean",
        ctc_zero_infinity=True,
        **kwargs,
    ):
        # Character-level CTC head.
        self.ctc_char_dropout = ctc_char_dropout
        self.ctc_char_vocab_size = ctc_char_vocab_size
        self.ctc_char_layers = ctc_char_layers
        self.ctc_char_hidden_layer = ctc_char_hidden_layer

        # Phoneme-level CTC head.
        self.ctc_phoneme_dropout = ctc_phoneme_dropout
        self.ctc_phoneme_vocab_size = ctc_phoneme_vocab_size
        self.ctc_phoneme_layers = ctc_phoneme_layers
        self.ctc_phoneme_hidden_layer = ctc_phoneme_hidden_layer

        # Voice-activity detection head.
        self.vad_hidden_layer = vad_hidden_layer
        self.vad_layers = vad_layers

        # Speaker-diarization head.
        self.diarization_hidden_layer = diarization_hidden_layer
        self.diarization_max_speakers = diarization_max_speakers
        self.diarization_layers = diarization_layers

        # CTC loss options.
        self.ctc_loss_reduction = ctc_loss_reduction
        self.ctc_zero_infinity = ctc_zero_infinity

        # Every remaining option belongs to the stock Whisper configuration.
        super().__init__(**kwargs)
|
|
|
|
@dataclass
class Seq2SeqMultitaskLMOutput(Seq2SeqLMOutput):
    """``Seq2SeqLMOutput`` plus the auxiliary-head outputs of ``WhisperMultitask``.

    The first nine fields repeat ``Seq2SeqLMOutput`` in the same order, so
    tuple conversion of the shared part is unchanged. The four trailing fields
    are filled by ``WhisperMultitask.forward`` only when its auxiliary heads run.

    ctc_char_logits: raw (pre-softmax) output of the character CTC projection.
    ctc_phoneme_logits: raw (pre-softmax) output of the phoneme CTC projection.
    vad_logits: sigmoid output of the VAD classifier. Despite the name, these
        are probabilities in [0, 1] with one unit per position.
    diarization_logits: sigmoid output of the diarization classifier, with one
        unit per speaker slot. These are also probabilities, not logits.
    (Shapes presumably (batch, encoder_frames, units) — TODO confirm.)
    """

    # --- fields inherited from Seq2SeqLMOutput (order matters for to_tuple) ---
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

    # --- auxiliary encoder-head outputs ---
    ctc_char_logits: Optional[torch.FloatTensor] = None
    ctc_phoneme_logits: Optional[torch.FloatTensor] = None
    vad_logits: Optional[torch.FloatTensor] = None
    diarization_logits: Optional[torch.FloatTensor] = None
|
|
|
|
| |
class WhisperMultitask(WhisperForConditionalGeneration):
    """Whisper encoder-decoder with auxiliary heads on the encoder.

    In addition to the decoder language-model head (``proj_out``), four heads
    read the encoder hidden states:

    * character-level CTC projection (``ctc_char_lm_head``),
    * phoneme-level CTC projection (``ctc_phoneme_lm_head``),
    * voice-activity detection (``vad_classifier``, 1 sigmoid unit),
    * speaker diarization (``diarization_classifier``, one sigmoid unit per
      speaker slot, ``diarization_max_speakers`` in total).

    Each head takes a single encoder hidden state (``*_hidden_layer``) or, when
    ``config.use_weighted_layer_sum`` is set, a learned softmax-weighted sum of
    all of them. It optionally refines that input with ``*_layers`` extra
    ``WhisperEncoderLayer`` blocks and then applies its own LayerNorm.

    The heads run only when ``forward`` receives a truthy
    ``output_hidden_states``. ``forward`` computes only the decoder
    cross-entropy loss; supervising the auxiliary outputs is left to the caller.
    """

    config_class = WhisperMultitaskConfig
    base_model_prefix = "model"
    _tied_weights_keys = ["proj_out.weight"]

    def __init__(self, config: WhisperMultitaskConfig):
        # WhisperForConditionalGeneration.__init__ already creates ``self.model``
        # (WhisperModel), ``self.proj_out`` and ``self.max_target_positions``.
        # Building them a second time here only doubled the backbone's
        # allocation and initialisation cost.
        super().__init__(config)

        # ``add_adapter``/``output_hidden_size`` are honoured if the config
        # defines them; otherwise the CTC projections read ``hidden_size`` features.
        output_hidden_size = (
            config.output_hidden_size if getattr(config, "add_adapter", False) else config.hidden_size
        )

        # --- character-level CTC head ---
        self.dropout_ctc_char = nn.Dropout(self._config_value(config, "ctc_char_dropout", 0.1))
        self.ctc_char_lm_head = nn.Linear(
            output_hidden_size, self._config_value(config, "ctc_char_vocab_size", 33), bias=True
        )
        self.ctc_char_hidden_layer = self._config_value(config, "ctc_char_hidden_layer", -1)
        self.ctc_char_layers = nn.ModuleList(
            [WhisperEncoderLayer(config) for _ in range(self._config_value(config, "ctc_char_layers", 0))]
        )
        self.ctc_char_layer_norm = nn.LayerNorm(config.d_model)

        # --- phoneme-level CTC head ---
        self.dropout_ctc_phoneme = nn.Dropout(self._config_value(config, "ctc_phoneme_dropout", 0.1))
        self.ctc_phoneme_lm_head = nn.Linear(
            output_hidden_size, self._config_value(config, "ctc_phoneme_vocab_size", 33), bias=True
        )
        self.ctc_phoneme_hidden_layer = self._config_value(config, "ctc_phoneme_hidden_layer", -1)
        self.ctc_phoneme_layers = nn.ModuleList(
            [WhisperEncoderLayer(config) for _ in range(self._config_value(config, "ctc_phoneme_layers", 0))]
        )
        self.ctc_phoneme_layer_norm = nn.LayerNorm(config.d_model)

        # --- voice-activity detection head (single output unit) ---
        self.vad_classifier = nn.Linear(config.hidden_size, 1, bias=True)
        self.vad_hidden_layer = self._config_value(config, "vad_hidden_layer", -1)
        self.vad_layers = nn.ModuleList(
            [WhisperEncoderLayer(config) for _ in range(self._config_value(config, "vad_layers", 0))]
        )
        self.vad_layer_norm = nn.LayerNorm(config.d_model)

        # --- speaker-diarization head (one output unit per speaker slot) ---
        self.diarization_max_speakers = self._config_value(config, "diarization_max_speakers", 5)
        self.diarization_classifier = nn.Linear(config.hidden_size, self.diarization_max_speakers, bias=True)
        self.diarization_hidden_layer = self._config_value(config, "diarization_hidden_layer", -1)
        self.diarization_layers = nn.ModuleList(
            [WhisperEncoderLayer(config) for _ in range(self._config_value(config, "diarization_layers", 0))]
        )
        self.diarization_layer_norm = nn.LayerNorm(config.d_model)

        # CTC loss options: stored for callers; ``forward`` below does not use them.
        self.ctc_loss_reduction = self._config_value(config, "ctc_loss_reduction", "mean")
        self.ctc_zero_infinity = self._config_value(config, "ctc_zero_infinity", True)

        if config.use_weighted_layer_sum:
            # One learnable weight per entry of ``encoder_hidden_states``
            # (embedding output + one per encoder layer), initialised uniform.
            num_layers = config.num_hidden_layers + 1
            self.ctc_char_layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
            self.ctc_phoneme_layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
            self.vad_layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
            self.diarization_layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)

        self.post_init()

    @staticmethod
    def _config_value(config, name, default):
        """Return ``config.<name>``, or ``default`` when it is missing or ``None``.

        A truthiness check would silently replace explicit falsy settings, such
        as ``0`` (the first hidden state), ``0.0`` (no dropout) or ``False``.
        This helper keeps them.
        """
        value = getattr(config, name, None)
        return default if value is None else value

    @staticmethod
    def _weighted_layer_sum(stacked_hidden_states, layer_weights):
        """Softmax-normalise ``layer_weights`` and mix the stacked hidden states with them.

        ``stacked_hidden_states`` is ``torch.stack(encoder_hidden_states, dim=1)``,
        so dim 1 indexes the layer. Each weight broadcasts over the trailing
        two dims (presumably sequence and hidden size — TODO confirm).
        """
        norm_weights = nn.functional.softmax(layer_weights, dim=-1)
        return (stacked_hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)

    @staticmethod
    def _run_task_layers(layers, layer_norm, hidden_states, head_mask, output_attentions):
        """Pass ``hidden_states`` through a head's extra encoder layers, then its LayerNorm.

        The LayerNorm is applied even when ``layers`` is empty.
        """
        for idx, encoder_layer in enumerate(layers):
            # NOTE(review): reuses the *backbone encoder's* head_mask row ``idx``
            # for these auxiliary layers — confirm this is intended.
            layer_outputs = encoder_layer(
                hidden_states,
                None,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]
        return layer_norm(hidden_states)

    def _compute_task_logits(self, encoder_hidden_states, head_mask, output_attentions):
        """Run the four auxiliary heads.

        Returns:
            ``(ctc_char_logits, ctc_phoneme_logits, vad_logits, diarization_logits)``.
            The CTC entries are raw projections. The VAD and diarization entries
            have already been passed through ``torch.sigmoid``.
        """
        if self.config.use_weighted_layer_sum:
            # Stack once and share the tensor between all four heads.
            stacked = torch.stack(encoder_hidden_states, dim=1)
            ctc_char_hidden_states = self._weighted_layer_sum(stacked, self.ctc_char_layer_weights)
            ctc_phoneme_hidden_states = self._weighted_layer_sum(stacked, self.ctc_phoneme_layer_weights)
            vad_hidden_states = self._weighted_layer_sum(stacked, self.vad_layer_weights)
            diarization_hidden_states = self._weighted_layer_sum(stacked, self.diarization_layer_weights)
        else:
            ctc_char_hidden_states = encoder_hidden_states[self.ctc_char_hidden_layer]
            ctc_phoneme_hidden_states = encoder_hidden_states[self.ctc_phoneme_hidden_layer]
            vad_hidden_states = encoder_hidden_states[self.vad_hidden_layer]
            diarization_hidden_states = encoder_hidden_states[self.diarization_hidden_layer]

        ctc_char_hidden_states = self._run_task_layers(
            self.ctc_char_layers, self.ctc_char_layer_norm, ctc_char_hidden_states, head_mask, output_attentions
        )
        # Each head uses its own LayerNorm (previously the phoneme branch used
        # ``ctc_char_layer_norm`` and ``ctc_phoneme_layer_norm`` was dead weight).
        ctc_phoneme_hidden_states = self._run_task_layers(
            self.ctc_phoneme_layers,
            self.ctc_phoneme_layer_norm,
            ctc_phoneme_hidden_states,
            head_mask,
            output_attentions,
        )

        ctc_char_logits = self.ctc_char_lm_head(self.dropout_ctc_char(ctc_char_hidden_states))
        ctc_phoneme_logits = self.ctc_phoneme_lm_head(self.dropout_ctc_phoneme(ctc_phoneme_hidden_states))

        vad_hidden_states = self._run_task_layers(
            self.vad_layers, self.vad_layer_norm, vad_hidden_states, head_mask, output_attentions
        )
        vad_logits = torch.sigmoid(self.vad_classifier(vad_hidden_states))

        diarization_hidden_states = self._run_task_layers(
            self.diarization_layers,
            self.diarization_layer_norm,
            diarization_hidden_states,
            head_mask,
            output_attentions,
        )
        diarization_logits = torch.sigmoid(self.diarization_classifier(diarization_hidden_states))

        return ctc_char_logits, ctc_phoneme_logits, vad_logits, diarization_logits

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
        decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        r"""
        Run the Whisper encoder-decoder and, optionally, the auxiliary encoder heads.

        Arguments are those of `WhisperForConditionalGeneration.forward`, with these specifics:

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
            or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
            only computed for the tokens with labels in `[0, ..., config.vocab_size]`. `sequence_length` should be smaller than or equal to `config.max_target_positions`.
            When no decoder inputs are given, `decoder_input_ids` are derived from `labels` via `shift_tokens_right`.
        output_hidden_states (`bool`, *optional*):
            Must be truthy for the auxiliary heads to run, because they read the encoder hidden states.
            `config.output_hidden_states` is not consulted for this decision.

        Returns:
            If `return_dict` is truthy (default `config.use_return_dict`), the result is a
            `Seq2SeqMultitaskLMOutput` when the auxiliary heads ran and a plain `Seq2SeqLMOutput`
            otherwise. Otherwise the result is a tuple `(loss?, lm_logits, *backbone outputs)`,
            followed by `(ctc_char_logits, ctc_phoneme_logits, vad_logits, diarization_logits)`
            when the auxiliary heads ran. `vad_logits` and `diarization_logits` are sigmoid
            probabilities.

        Raises:
            ValueError: if `labels` is longer than `config.max_target_positions`.

        Example (plain transcription, identical to the parent class):

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
        >>> from datasets import load_dataset

        >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
        >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

        >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
        >>> input_features = inputs.input_features

        >>> generated_ids = model.generate(inputs=input_features)

        >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> transcription
        ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if labels.shape[1] > self.max_target_positions:
                raise ValueError(
                    f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length of {self.max_target_positions} tokens."
                )
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        # Always request a ModelOutput from the backbone so encoder hidden states
        # can be read by name. With a tuple, ``outputs.encoder_hidden_states``
        # used to raise AttributeError when return_dict=False. The tuple form is
        # rebuilt below when the caller asked for it.
        outputs = self.model(
            input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
        )

        task_logits = None
        if output_hidden_states and outputs.encoder_hidden_states:
            task_logits = self._compute_task_logits(outputs.encoder_hidden_states, head_mask, output_attentions)

        lm_logits = self.proj_out(outputs.last_hidden_state)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Labels may live on another device than the logits (e.g. model parallelism).
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))

        if not return_dict:
            # Same layout the backbone's tuple output would have produced (None entries dropped).
            output = (lm_logits,) + outputs.to_tuple()[1:]
            if task_logits is not None:
                output = output + task_logits
            return ((loss,) + output) if loss is not None else output

        if task_logits is not None:
            ctc_char_logits, ctc_phoneme_logits, vad_logits, diarization_logits = task_logits
            return Seq2SeqMultitaskLMOutput(
                loss=loss,
                logits=lm_logits,
                past_key_values=outputs.past_key_values,
                decoder_hidden_states=outputs.decoder_hidden_states,
                decoder_attentions=outputs.decoder_attentions,
                cross_attentions=outputs.cross_attentions,
                encoder_last_hidden_state=outputs.encoder_last_hidden_state,
                encoder_hidden_states=outputs.encoder_hidden_states,
                encoder_attentions=outputs.encoder_attentions,
                ctc_char_logits=ctc_char_logits,
                ctc_phoneme_logits=ctc_phoneme_logits,
                vad_logits=vad_logits,
                diarization_logits=diarization_logits,
            )

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
| |
|
|