import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Union, Tuple, List
from transformers import LlamaForSequenceClassification, Cache
from transformers.modeling_outputs import SequenceClassifierOutputWithPast

class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy with label smoothing: the target distribution mixes the
    one-hot label with a uniform distribution over all classes."""

    def __init__(self, eps: float = 0.1):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        num_classes = pred.size(-1)
        log_preds = F.log_softmax(pred, dim=-1)
        # Negative log-likelihood of the true class for each sample.
        nll = F.nll_loss(log_preds, target, reduction='none')
        # Uniform component: summed negative log-probabilities, later divided
        # by the number of classes.
        smooth_loss = -log_preds.sum(dim=-1)
        loss = (1 - self.eps) * nll + self.eps * smooth_loss / num_classes
        return loss.mean()

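# Illustrative usage of the smoothing loss (made-up shapes, not taken from the
# training pipeline):
#     criterion = LabelSmoothingCrossEntropy(eps=0.1)
#     logits = torch.randn(4, 65)           # (batch, num_labels)
#     targets = torch.randint(0, 65, (4,))  # integer class indices
#     loss = criterion(logits, targets)     # scalar, averaged over the batch
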
class Weights(nn.Module):
    """Three-layer MLP classification head with SELU activations."""

    def __init__(self, hidden_size: int = 4096, num_labels: int = 65):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.SELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.SELU(),
            nn.Linear(hidden_size, num_labels),
        )

    def forward(self, x):
        return self.fc(x)

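# Shape sketch (illustrative): Weights(hidden_size=4096, num_labels=65) maps a
# pooled tensor of shape (batch, 4096) to logits of shape (batch, 65).
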
class LlamaForSequenceClassificationWithCustomHead(LlamaForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        # Custom MLP head and smoothed loss, used instead of the default
        # single-linear `score` head and plain cross-entropy.
        self.weights = Weights(hidden_size=config.hidden_size, num_labels=65)
        self.loss_fn = LabelSmoothingCrossEntropy(eps=0.1)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Padding token ID must be set for batch size > 1.")

        # Pool the hidden state of the last non-padding token in each sequence
        # (same convention as the stock LlamaForSequenceClassification).
        if self.config.pad_token_id is not None and input_ids is not None:
            sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(dim=-1) - 1
            sequence_lengths = sequence_lengths % input_ids.shape[-1]
        else:
            sequence_lengths = -1

        pooled_hidden = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths]
        logits = self.weights(pooled_hidden)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss = self.loss_fn(logits, labels)

        if not return_dict:
            # Keep the base-class tuple layout: logits first, then the remaining
            # transformer outputs (past_key_values, hidden states, attentions).
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
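

# Minimal smoke test: a sketch under assumed settings (a tiny random-weight
# LlamaConfig and made-up batch shapes, not a real checkpoint or dataset).
if __name__ == "__main__":
    from transformers import LlamaConfig

    config = LlamaConfig(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        pad_token_id=0,
        num_labels=65,
    )
    model = LlamaForSequenceClassificationWithCustomHead(config)

    input_ids = torch.randint(1, 128, (2, 16))   # dummy token ids, no padding
    labels = torch.randint(0, 65, (2,))          # dummy class labels
    out = model(input_ids=input_ids, labels=labels)
    print(out.loss, out.logits.shape)            # expect logits of shape (2, 65)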