from typing import Optional, Union, Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LlamaForSequenceClassification, Cache
from transformers.modeling_outputs import SequenceClassifierOutputWithPast


# 1. Label smoothing cross entropy loss
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy loss with uniform label smoothing.

    The per-sample loss is ``(1 - eps) * NLL + eps * sum_c(-log p_c) / C``,
    i.e. cross-entropy against the target distribution
    ``(1 - eps) * one_hot(target) + eps / C``. The result is averaged over
    samples whose target is not ``ignore_index``.

    Args:
        eps: Smoothing factor in ``[0, 1]``; ``0`` gives plain cross-entropy.
        ignore_index: Target value whose samples are excluded from the loss
            and from the mean (``-100`` matches the Hugging Face convention).

    Raises:
        ValueError: if ``eps`` is outside ``[0, 1]``.
    """

    def __init__(self, eps: float = 0.1, ignore_index: int = -100):
        super().__init__()
        if not 0.0 <= eps <= 1.0:
            raise ValueError(f"eps must be in [0, 1], got {eps}")
        self.eps = eps
        self.ignore_index = ignore_index

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the mean label-smoothed loss.

        Args:
            pred: Unnormalized logits, expected shape ``(N, C)``.
            target: Class indices, expected shape ``(N,)``.

        Returns:
            Scalar loss tensor.
        """
        # label_smoothing implements the same (1 - eps) * nll + eps * smooth / C
        # formula as a manual log_softmax/nll_loss combination, but also masks
        # ignore_index targets out of the smoothing term and the mean
        # denominator (a manual version only zeroes them in the NLL term).
        return F.cross_entropy(
            pred,
            target,
            ignore_index=self.ignore_index,
            label_smoothing=self.eps,
        )


# 2. Custom classification head
class Weights(nn.Module):
    """Three-layer MLP classification head.

    Maps a vector of size ``hidden_size`` through two ``hidden_size`` x
    ``hidden_size`` SELU layers to ``num_labels`` unnormalized logits.

    NOTE(review): SELU is normally paired with LeCun-normal initialization for
    its self-normalizing behaviour; these layers keep ``nn.Linear``'s default
    init — confirm that is intended.
    """

    def __init__(self, hidden_size: int = 4096, num_labels: int = 65):
        super().__init__()
        # Layer order/indices (fc.0, fc.2, fc.4) define the state-dict keys.
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.SELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.SELU(),
            nn.Linear(hidden_size, num_labels),
        )

    def forward(self, x):
        return self.fc(x)  # logits, not probabilities
# 3. Modified LLaMA model
class LlamaForSequenceClassificationWithCustomHead(LlamaForSequenceClassification):
    """``LlamaForSequenceClassification`` with an MLP head and a smoothed loss.

    The decoder's hidden state at one token per row (the token just before the
    first pad token when a pad token is configured, otherwise the last
    position) is passed through ``Weights`` to produce 65 logits. When
    ``labels`` are given, the loss is ``LabelSmoothingCrossEntropy(eps=0.1)``.

    NOTE(review): the head size is hard-coded to 65 and does not follow
    ``config.num_labels``. The parent ``__init__`` presumably still builds its
    own classifier layer, which this ``forward`` never uses — confirm whether
    it should be dropped.
    """

    def __init__(self, config):
        super().__init__(config)
        # Custom head used by forward() in place of the parent's classifier.
        self.weights = Weights(hidden_size=config.hidden_size, num_labels=65)
        self.loss_fn = LabelSmoothingCrossEntropy(eps=0.1)

    def _last_token_indices(
        self,
        input_ids: Optional[torch.LongTensor],
        device: torch.device,
    ) -> Union[int, torch.Tensor]:
        """Return the per-row sequence index to pool from.

        With a pad token and ``input_ids`` available, each row uses the position
        just before its first pad token. A row with no pad token, or whose first
        pad is at position 0 (left padding), wraps to the last position via the
        modulo. Otherwise ``-1`` (the last position) is used for every row.
        """
        if self.config.pad_token_id is None or input_ids is None:
            return -1  # assume last token if pad token is undefined
        first_pad = torch.eq(input_ids, self.config.pad_token_id).int().argmax(dim=-1)
        indices = (first_pad - 1) % input_ids.shape[-1]
        # input_ids can live on a different device than the final hidden states
        # when the model is split across devices; advanced indexing needs the
        # index tensor on the indexed tensor's device.
        return indices.to(device)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        """Run the decoder, pool one token per row, and classify it.

        Returns:
            If ``return_dict`` is false: ``(loss, logits)`` when ``labels`` is
            given, else ``(logits,)``. Otherwise a
            ``SequenceClassifierOutputWithPast`` whose ``logits`` has shape
            ``[batch_size, 65]``.

        Raises:
            ValueError: if ``config.pad_token_id`` is unset and the batch has
                more than one row.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]  # [batch_size, seq_len, hidden_size]

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Padding token ID must be set for batch size > 1.")

        # Get sequence index (last non-padding token), on hidden_states' device.
        sequence_lengths = self._last_token_indices(input_ids, hidden_states.device)
        pooled_hidden = hidden_states[
            torch.arange(batch_size, device=hidden_states.device), sequence_lengths
        ]

        logits = self.weights(pooled_hidden)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss = self.loss_fn(logits, labels)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,  # shape [batch_size, 65]
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )