| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers.models.seamless_m4t.modeling_seamless_m4t import ( |
| _compute_new_attention_mask, |
| ) |
| from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2 import ( |
| SeamlessM4Tv2SpeechEncoder, |
| SeamlessM4Tv2PreTrainedModel, |
| ) |
| from .configuration_seamless_m4t_v2_speech_encoder import ( |
| MODEL_TYPE, |
| SeamlessM4Tv2EncoderConfig, |
| ) |
| from transformers.modeling_outputs import SequenceClassifierOutput |
|
|
| from transformers.models.auto import ( |
| AutoModel, |
| AutoModelForAudioClassification, |
| AutoModelForSequenceClassification, |
| ) |
|
|
|
|
class SeamlessM4Tv2SpeechEncoder(SeamlessM4Tv2SpeechEncoder):
    """Speech encoder wired to a custom config for Auto-class loading.

    Deliberately shadows the imported upstream ``SeamlessM4Tv2SpeechEncoder``
    (the base class) so that the same architecture can be registered under
    ``SeamlessM4Tv2EncoderConfig`` / ``MODEL_TYPE`` below. Behavior is
    inherited unchanged; this subclass only adds ``mean_pooling``.
    """

    model_type = MODEL_TYPE
    config_class = SeamlessM4Tv2EncoderConfig

    # NOTE: the previous trivial ``__init__`` override (super-call only)
    # was removed — the inherited constructor is used directly.

    @staticmethod
    def mean_pooling(
        hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Return the mask-weighted mean of ``hidden_states`` over dim 1.

        Args:
            hidden_states: float tensor; the mask is broadcast to its shape,
                so it is assumed to be ``(batch, seq_len, hidden)`` with the
                mask covering the first two dims — TODO confirm with callers.
            attention_mask: tensor of 1s for valid positions and 0s for
                padding, shape ``hidden_states.shape[:-1]``.

        Returns:
            Tensor with dim 1 reduced (``(batch, hidden)`` for 3-D input):
            the mean of unmasked positions only.
        """
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        )
        sum_hidden_states = torch.sum(hidden_states * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        # Clamp keeps fully-masked rows from dividing by zero.
        return sum_hidden_states / torch.clamp(sum_mask, min=1e-9)
|
|
|
|
class SeamlessM4Tv2ForAudioClassification(SeamlessM4Tv2PreTrainedModel):
    """SeamlessM4T-v2 speech encoder topped with a classification head.

    Encodes audio features, mean-pools the final hidden states over time
    (excluding padding), and projects to ``config.num_labels`` logits with
    a bias-free linear layer.
    """

    model_type = MODEL_TYPE
    base_model_prefix = "model"
    config_class = SeamlessM4Tv2EncoderConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = SeamlessM4Tv2SpeechEncoder(config)
        # Bias-free scoring head over the pooled hidden state.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor | None = None,
        *args,
        **kwargs,
    ) -> SequenceClassifierOutput:
        """Run the encoder and classification head.

        Args:
            input_features: audio features forwarded to the encoder.
            attention_mask: padding mask for ``input_features``.
            labels: optional targets; when given, a loss is computed whose
                kind follows ``config.problem_type`` (inferred on first use,
                mirroring other HF classification heads).
            *args, **kwargs: passed through to the encoder;
                ``output_hidden_states`` is intercepted here.

        Returns:
            ``SequenceClassifierOutput`` with ``loss`` (``None`` when no
            labels are provided), ``logits``, and optionally the encoder's
            intermediate hidden states.
        """
        output_hidden_states = kwargs.pop("output_hidden_states", False)
        outputs = self.model(
            input_features,
            attention_mask,
            *args,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state
        # The adapter subsamples the time axis, so the mask must be
        # recomputed to match the shorter output sequence before pooling.
        if attention_mask is not None and self.model.config.add_adapter:
            sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(
                attention_mask
            ).to(hidden_states.device)
            attention_mask = _compute_new_attention_mask(
                hidden_states=hidden_states, seq_lens=sub_sampled_lengths
            )
        pooled = self.model.mean_pooling(hidden_states, attention_mask)
        logits = self.score(pooled)

        # BUG FIX: previously ``loss`` was only bound inside the labels
        # branch, raising UnboundLocalError whenever labels was None.
        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = F.mse_loss(logits.squeeze(), labels.squeeze())
                else:
                    loss = F.mse_loss(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = F.cross_entropy(
                    logits.view(-1, self.num_labels), labels.view(-1)
                )
            elif self.config.problem_type == "multi_label_classification":
                loss = F.binary_cross_entropy_with_logits(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
        )
|
|
|
|
# Register the custom config/model pairs with the transformers Auto-class
# factories so `AutoModel.from_pretrained(...)` & co. can resolve them by
# config type. The classification model is registered under both the audio
# and sequence classification factories so either entry point works.
AutoModel.register(SeamlessM4Tv2EncoderConfig, SeamlessM4Tv2SpeechEncoder)
AutoModelForAudioClassification.register(
    SeamlessM4Tv2EncoderConfig, SeamlessM4Tv2ForAudioClassification
)
AutoModelForSequenceClassification.register(
    SeamlessM4Tv2EncoderConfig, SeamlessM4Tv2ForAudioClassification
)
|
|