| | import math |
| | import warnings |
| | from typing import Union, Tuple, Optional |
| |
|
| | import numpy as np |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from transformers.modeling_utils import PreTrainedModel |
| | from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutput |
| | from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled |
| | from transformers.integrations.fsdp import is_fsdp_managed_module |
| | from transformers.models.hubert.modeling_hubert import ( |
| | HubertFeatureEncoder, |
| | HubertFeatureProjection, |
| | HubertEncoderStableLayerNorm, |
| | HubertEncoder, |
| | _HIDDEN_STATES_START_POSITION |
| | ) |
| |
|
| | from .configuration_hubert_spkreg import HubertSpkRegConfig |
| |
|
| |
|
class HubertSpkRegPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = HubertSpkRegConfig
    base_model_prefix = "hubert"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights of a single sub-module.

        * `nn.Linear`: normal init with std `config.initializer_range`, zero bias.
        * `AngularLinear`: Xavier-normal init (the same scheme its constructor uses).
        * `nn.LayerNorm` / `nn.GroupNorm`: weight 1, bias 0 (skipped when the norm has no affine parameters).
        * `nn.Conv1d`: Kaiming-normal init, zero bias; parameters are gathered first under DeepSpeed ZeRO-3.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, AngularLinear):
            # Without this branch the margin-softmax head is never touched here.
            # NOTE(review): recent transformers releases appear to rely on `_init_weights` to fill weights that
            # are missing from a checkpoint, which would leave this head uninitialized - confirm against the
            # installed version.
            if is_deepspeed_zero3_enabled():
                import deepspeed

                with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
                    nn.init.xavier_normal_(module.weight.data, gain=1)
            else:
                nn.init.xavier_normal_(module.weight.data, gain=1)
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            # Norm layers built without affine parameters expose `None` here.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            if is_deepspeed_zero3_enabled():
                import deepspeed

                # Weight-normalized convolutions keep their parameters in `weight_v` / `weight_g`.
                if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
                    with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
                        nn.init.kaiming_normal_(module.weight.data)
                else:
                    with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
                        nn.init.kaiming_normal_(module.weight.data)
            else:
                nn.init.kaiming_normal_(module.weight.data)

        if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
            module.bias.data.zero_()

    def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # Output length of one convolution: floor((length - kernel_size) / stride) + 1
            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

        # Apply the formula once per (kernel, stride) pair listed in the config.
        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        return input_lengths

    def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
        """Build a boolean mask of shape `(batch_size, feature_vector_length)` from a sample-level mask.

        The number of valid frames per row is derived from `attention_mask.sum(-1)` through
        `_get_feat_extract_output_lengths`; every position up to and including the last valid frame is `True`.
        """
        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
        batch_size = attention_mask.shape[0]

        attention_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )
        # Mark the last valid frame of every row ...
        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
        # ... then a reversed cumulative sum turns every position before that marker into a non-zero value.
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return attention_mask
| | |
| |
|
| | |
def _compute_mask_indices(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: Optional[torch.LongTensor] = None,
    min_masks: int = 0,
) -> np.ndarray:
    """
    Draw random mask spans for a `(batch_size, sequence_length)` grid, as used by [SpecAugment: A Simple Data
    Augmentation Method for ASR](https://arxiv.org/abs/1904.08779). The sampling is done with NumPy.

    Args:
        shape: `(batch_size, sequence_length)` of the mask to build.
        mask_prob: Fraction of the axis (between 0 and 1) to mask. Roughly `mask_prob * length / mask_length`
            spans are drawn per row; spans may overlap, so the masked fraction can be smaller.
        mask_length: Length of every span.
        attention_mask: Optional mask whose row sums give the usable length of each row.
        min_masks: Minimum number of spans per row.

    Returns:
        Boolean array of shape `(batch_size, sequence_length)`; `True` marks masked positions.

    Raises:
        ValueError: If `mask_length` is smaller than 1 or larger than `sequence_length`.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
            f" and `sequence_length`: {sequence_length}`"
        )

    # A single random offset shared by all rows makes the span count round up or down at random.
    epsilon = np.random.rand(1).item()

    def compute_num_masked_span(input_length):
        """Given input length, compute how many spans should be masked"""
        span_count = max(int(mask_prob * input_length / mask_length + epsilon), min_masks)

        # The spans together must not be longer than the sequence.
        if span_count * mask_length > sequence_length:
            span_count = sequence_length // mask_length

        # There are only `input_length - (mask_length - 1)` distinct start positions.
        available_starts = input_length - (mask_length - 1)
        if available_starts < span_count:
            span_count = max(available_starts, 0)

        return span_count

    # Usable length of every row.
    if attention_mask is not None:
        input_lengths = attention_mask.sum(-1).detach().tolist()
    else:
        input_lengths = [sequence_length] * batch_size

    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)

    max_num_masked_span = compute_num_masked_span(sequence_length)
    if max_num_masked_span == 0:
        return spec_aug_mask

    per_row_starts = []
    for row_length in input_lengths:
        row_span_count = compute_num_masked_span(row_length)

        # Sample distinct start positions for this row.
        row_starts = np.random.choice(np.arange(row_length - (mask_length - 1)), row_span_count, replace=False)

        # Rows with fewer spans are padded with a repeated index so all rows have the same width. With no span at
        # all the last position of the sequence is used, otherwise the first sampled start is repeated.
        filler_idx = sequence_length - 1 if len(row_starts) == 0 else row_starts[0]
        padding = np.ones(max_num_masked_span - row_span_count, dtype=np.int32) * filler_idx
        per_row_starts.append(np.concatenate([row_starts, padding]))

    span_starts = np.array(per_row_starts)

    # Expand each start index into the `mask_length` consecutive positions it covers.
    span_offsets = np.arange(mask_length)[None, None, :]
    span_positions = (span_starts[:, :, None] + span_offsets).reshape(
        batch_size, max_num_masked_span * mask_length
    )

    # Keep every position inside the sequence.
    span_positions = np.minimum(span_positions, sequence_length - 1)

    np.put_along_axis(spec_aug_mask, span_positions, 1, -1)

    return spec_aug_mask
| |
|
| |
|
class HubertSpkRegModel(HubertSpkRegPreTrainedModel):
    """HuBERT backbone: convolutional feature encoder, feature projection, optional SpecAugment masking and a
    transformer encoder. `forward` returns the encoder's hidden states."""

    def __init__(self, config: HubertSpkRegConfig):
        super().__init__(config)
        self.config = config
        self.feature_extractor = HubertFeatureEncoder(config)
        self.feature_projection = HubertFeatureProjection(config)

        # The learned mask embedding is only created when some form of masking is configured.
        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())

        if config.do_stable_layer_norm:
            self.encoder = HubertEncoderStableLayerNorm(config)
        else:
            self.encoder = HubertEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).

        Args:
            hidden_states: Tensor of size `(batch_size, sequence_length, hidden_size)`; modified in place.
            mask_time_indices: Optional precomputed time mask. When given it is applied as is; otherwise a mask
                is sampled, but only in training mode and when `config.mask_time_prob > 0`.
            attention_mask: Optional frame-level mask; its row sums limit where sampled time spans may start.

        Returns:
            The (possibly masked) `hidden_states`.
        """

        # `config.apply_spec_augment` can switch masking off entirely; it defaults to enabled when absent.
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        batch_size, sequence_length, hidden_size = hidden_states.size()

        if mask_time_indices is not None:
            # Use the caller-provided time mask.
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
        elif self.config.mask_time_prob > 0 and self.training:
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # Sample spans along the feature axis and zero them for every time step.
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """
        Run the feature encoder, projection, optional masking and the transformer encoder.

        Args:
            input_values: Input passed to the convolutional feature encoder.
            attention_mask: Optional sample-level mask; it is converted to a frame-level mask before use.
            mask_time_indices: Optional precomputed time mask forwarded to `_mask_hidden_states`.
            output_attentions: Falls back to `config.output_attentions` when `None`.
            output_hidden_states: Falls back to `config.output_hidden_states` when `None`.
            return_dict: Falls back to `config.use_return_dict` when `None`.

        Returns:
            A `BaseModelOutput`, or a plain tuple starting with the last hidden state when `return_dict` is
            false.

        Example (written against the upstream `HubertModel`; NOTE(review): presumably this class is called the
        same way - confirm the checkpoint is compatible before copying):

        ```python
        >>> from transformers import AutoProcessor, HubertModel
        >>> from datasets import load_dataset
        >>> import soundfile as sf

        >>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
        >>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")


        >>> def map_to_array(batch):
        ...     speech, _ = sf.read(batch["file"])
        ...     batch["speech"] = speech
        ...     return batch


        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.map(map_to_array)

        >>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values  # Batch size 1
        >>> hidden_states = model(input_values).last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        if attention_mask is not None:
            # Reduce the sample-level mask to one entry per extracted frame.
            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

        hidden_states = self.feature_projection(extract_features)
        # Forward the frame-level mask so that sampled time spans stay within each row's valid length.
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
| | |
| |
|
class AngularLinear(nn.Module):
    """Bias-free linear layer that outputs cosine similarities.

    Both the input vectors and the rows of `weight` are L2-normalized along their last dimension before the
    matrix product, so every output lies in `[-1, 1]`.
    """

    def __init__(self, in_features: int, out_features: int):
        """
        Args:
            in_features: Size of each input vector.
            out_features: Number of weight rows, i.e. size of each output vector.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # `torch.empty` follows the active default dtype/device, unlike the legacy `torch.FloatTensor(...)`.
        self.weight = nn.Parameter(torch.empty(out_features, in_features), requires_grad=True)
        nn.init.xavier_normal_(self.weight, gain=1)

    def forward(
        self,
        inputs: torch.Tensor,
    ):
        """Return the cosine between `inputs` (shape `(..., in_features)`) and every weight row.

        Normalizing over the last dimension (instead of `F.normalize`'s default `dim=1`) keeps the result correct
        for inputs that are not exactly 2-D.
        """
        cosine = F.linear(F.normalize(inputs, dim=-1), F.normalize(self.weight, dim=-1))
        return cosine

    def extra_repr(self) -> str:
        """Describe the layer dimensions for `repr()`."""
        return f"in_features={self.in_features}, out_features={self.out_features}"
| |
|
| |
|
class AMSoftmaxLoss(nn.Module):
    """Additive Margin Softmax (CosFace).

    Paper: Wang, Feng, et al. "Additive margin softmax for face verification."
    IEEE Signal Processing Letters 25.7 (2018): 926-930.
    """

    def __init__(
        self,
        scale: float = 30.0,
        margin: float = 0.35,
        label_smoothing: float = 0.0,
        reduction: str = "mean"
    ):
        """
        Args:
            scale: Scaling factor applied to the logits (default: 30.0)
            margin: Margin subtracted from the target-class cosine (default: 0.35)
            label_smoothing: Label smoothing factor passed to `F.cross_entropy` (default: 0.0)
            reduction: Reduction method passed to `F.cross_entropy` (default: "mean")
        """
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
    ):
        """
        Args:
            inputs: Cosine similarities of shape (batch_size, num_labels)
            targets: Ground truth labels of shape (batch_size)
        Returns:
            Loss value
        """
        _, num_labels = inputs.shape
        # Keep the cosines strictly inside (-1, 1).
        cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
        psi = cos_theta - self.margin
        # A comparison mask (rather than `one_hot`) also accepts labels equal to `cross_entropy`'s default
        # `ignore_index` of -100: such rows get no margin and are skipped by the loss.
        class_ids = torch.arange(num_labels, device=targets.device)
        is_target = targets.unsqueeze(-1) == class_ids
        outputs = self.scale * torch.where(is_target, psi, cos_theta)
        loss = F.cross_entropy(
            outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
        )
        return loss
| |
|
| |
|
class AAMSoftmaxLoss(nn.Module):
    """Additive Angular Margin Softmax (ArcFace).

    Paper: Deng, Jiankang, et al. "Arcface: Additive angular margin loss for deep face recognition."
    Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.
    """

    def __init__(
        self,
        scale: float = 30.0,
        margin: float = 0.2,
        easy_margin: bool = False,
        label_smoothing: float = 0.0,
        reduction: str = "mean"
    ):
        """
        Args:
            scale: Scaling factor applied to the logits (default: 30.0)
            margin: Angular margin in radians added to the target-class angle (default: 0.2)
            easy_margin: Use the easy margin loss (default: False)
            label_smoothing: Label smoothing factor passed to `F.cross_entropy` (default: 0.0)
            reduction: Reduction method passed to `F.cross_entropy` (default: "mean")
        """
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.easy_margin = easy_margin
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
    ):
        """
        Args:
            inputs: Cosine similarities of shape (batch_size, num_labels)
            targets: Ground truth labels of shape (batch_size)
        Returns:
            Loss value
        """
        _, num_labels = inputs.shape

        epsilon = 1e-6

        # Clamp so that the square root below stays away from zero.
        cos_theta = torch.clamp(inputs, -1.0 + epsilon, 1.0 - epsilon)
        sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
        sin_theta = torch.clamp(sin_theta, 0.0 + epsilon, 1.0 - epsilon)

        # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        cos_m = math.cos(self.margin)
        sin_m = math.sin(self.margin)
        psi = cos_theta * cos_m - sin_theta * sin_m

        if self.easy_margin:
            # Only apply the margin where the cosine is positive.
            psi = torch.where(cos_theta > 0, psi, cos_theta)
        else:
            # Once theta + m would exceed pi, switch to subtracting the margin from the cosine instead.
            psi = torch.where((cos_theta - math.cos(math.pi - self.margin)) > 0, psi, cos_theta - self.margin)

        # A comparison mask (rather than `one_hot`) also accepts labels equal to `cross_entropy`'s default
        # `ignore_index` of -100: such rows get no margin and are skipped by the loss.
        class_ids = torch.arange(num_labels, device=targets.device)
        is_target = targets.unsqueeze(-1) == class_ids
        outputs = self.scale * torch.where(is_target, psi, cos_theta)
        loss = F.cross_entropy(
            outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
        )
        return loss
| | |
| |
|
class HubertSpkRegForSequenceClassification(HubertSpkRegPreTrainedModel):
    """HuBERT backbone with a mean-pooled classification head.

    The head is an `nn.Linear` for `config.loss_fct == 'cross_entropy'` and an `AngularLinear` (cosine logits) for
    `'additive_margin'` and `'additive_angular_margin'`.
    """

    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Sequence classification does not support the use of Hubert adapters (config.add_adapter=True)"
            )
        self.hubert = HubertSpkRegModel(config)
        # One weight per hidden-state tensor returned by the backbone (`num_hidden_layers + 1` of them).
        num_layers = config.num_hidden_layers + 1
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)

        if self.config.loss_fct == 'cross_entropy':
            self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
        elif self.config.loss_fct == 'additive_margin':
            self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
        elif self.config.loss_fct == 'additive_angular_margin':
            self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
        else:
            raise ValueError(f"Unsupported loss function: {self.config.loss_fct}")

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.hubert.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.hubert.parameters():
            param.requires_grad = False

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Class indices used to compute the classification loss selected by `config.loss_fct`
            (`'cross_entropy'`, `'additive_margin'` or `'additive_angular_margin'`). No loss is computed when
            `labels` is `None`.

        Raises:
            ValueError: If `labels` is given and `config.loss_fct` is not one of the supported values.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The weighted layer sum needs every hidden state, so force the backbone to return them.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.hubert(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # Softmax-weighted average over the stacked hidden states.
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)
        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            # Mean over valid frames only: zero the padded frames, then divide by the number of valid ones.
            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
            hidden_states[~padding_mask] = 0.0
            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.loss_fct == 'cross_entropy':
                loss_fct = nn.CrossEntropyLoss(
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )
            elif self.config.loss_fct == 'additive_margin':
                loss_fct = AMSoftmaxLoss(
                    scale=self.config.scale,
                    margin=self.config.margin,
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )
            elif self.config.loss_fct == 'additive_angular_margin':
                loss_fct = AAMSoftmaxLoss(
                    scale=self.config.scale,
                    margin=self.config.margin,
                    easy_margin=self.config.easy_margin,
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )
            else:
                # Reachable when the config is changed after construction; fail with the constructor's message
                # instead of an unbound-variable error.
                raise ValueError(f"Unsupported loss function: {self.config.loss_fct}")
            loss = loss_fct(
                logits.view(-1, self.config.num_labels),
                labels.view(-1),
            )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )