# map_65label_v4 / modeling_custom.py
# Uploaded by qwertyuiopasdfg via huggingface_hub ("Upload folder using huggingface_hub"), commit ffe372d (verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Union, Tuple, List
from transformers import LlamaForSequenceClassification, Cache
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
# 1. Label smoothing cross entropy loss
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy loss with uniform label smoothing.

    Per sample the loss is ``(1 - eps) * NLL(target) + eps * mean_c(-log p_c)``:
    a blend of the usual negative log-likelihood of the true class and the
    average negative log-probability over all classes. The per-sample values
    are then averaged over the batch.

    Args:
        eps: Smoothing weight; ``0`` reduces this to plain cross-entropy.
    """

    def __init__(self, eps: float = 0.1):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        """Return the mean label-smoothed cross-entropy.

        Args:
            pred: Unnormalised logits, typically ``(N, C)``. The softmax is taken
                over the last dimension, while ``F.nll_loss`` reads classes from
                dim 1, so 2-D input is what this is written for.
            target: Integer class indices as accepted by ``F.nll_loss``.

        Returns:
            Scalar tensor holding the batch-mean smoothed loss.
        """
        log_probs = F.log_softmax(pred, dim=-1)
        n_classes = pred.size(-1)
        # Negative log-likelihood of the true class for each sample (kept per-sample).
        target_term = F.nll_loss(log_probs, target, reduction='none')
        # Sum of -log p over all classes; divided by C below it becomes the uniform-prior term.
        uniform_term = -log_probs.sum(dim=-1)
        per_sample = (1 - self.eps) * target_term + self.eps * uniform_term / n_classes
        return per_sample.mean()
# 2. Custom classification head
class Weights(nn.Module):
    """MLP classification head: two SELU-activated hidden layers, then a projection to logits.

    Layer stack inside ``self.fc``: Linear(hidden, hidden) -> SELU ->
    Linear(hidden, hidden) -> SELU -> Linear(hidden, num_labels).
    The state_dict keys (``fc.0``, ``fc.2``, ``fc.4``) follow from this layout;
    presumably saved checkpoints depend on it, so keep the ordering intact.

    Args:
        hidden_size: Width of the input features and of both hidden layers.
        num_labels: Number of output logits.
    """

    def __init__(self, hidden_size: int = 4096, num_labels: int = 65):
        super().__init__()
        layers = []
        for _ in range(2):
            layers.append(nn.Linear(hidden_size, hidden_size))
            layers.append(nn.SELU())
        layers.append(nn.Linear(hidden_size, num_labels))
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Map features whose last dim is ``hidden_size`` to raw (unnormalised) logits."""
        logits = self.fc(x)
        return logits
# 3. Modified LLaMA model
class LlamaForSequenceClassificationWithCustomHead(LlamaForSequenceClassification):
    """LLaMA sequence classifier that feeds one pooled token state to a custom MLP head.

    The backbone ``self.model`` is inherited from ``LlamaForSequenceClassification``.
    For each sequence, the hidden state at the last non-padding position (or the
    final position when that cannot be determined) goes through ``Weights``, which
    produces 65 logits. Training uses label-smoothed cross-entropy with ``eps=0.1``.

    NOTE(review): the parent constructor also creates ``self.score``, which
    ``forward`` never uses. It is deliberately kept because removing it would
    change the state_dict keys that existing checkpoints presumably contain.
    """

    def __init__(self, config):
        super().__init__(config)
        # The head size is fixed at 65 regardless of ``config.num_labels``;
        # presumably this matches the trained checkpoint. Verify before changing.
        self.weights = Weights(hidden_size=config.hidden_size, num_labels=65)
        self.loss_fn = LabelSmoothingCrossEntropy(eps=0.1)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        """Run the backbone, pool one token per sequence, and classify it.

        Args:
            input_ids: Token ids. When ``config.pad_token_id`` is set, they are
                also used to locate the last non-padding token of each row.
            attention_mask, position_ids, past_key_values, inputs_embeds,
            use_cache, output_attentions, output_hidden_states: Forwarded
                unchanged to ``self.model``.
            labels: Class indices in ``[0, 65)``, one per sequence. Any shape with
                ``batch_size`` elements is accepted because it is flattened.
            return_dict: Return a ``SequenceClassifierOutputWithPast`` rather than a
                tuple. Defaults to ``config.use_return_dict``.

        Returns:
            With ``return_dict``: a ``SequenceClassifierOutputWithPast`` whose
            ``logits`` have shape ``(batch_size, 65)`` and whose cache, hidden
            states and attentions come from the backbone. Otherwise
            ``(loss, logits)`` when labels are given, else ``(logits,)``. Backbone
            outputs are not included in the tuple form.

        Raises:
            ValueError: If ``config.pad_token_id`` is None and batch_size != 1.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]  # [batch_size, seq_len, hidden_size]

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Padding token ID must be set for batch size > 1.")

        # Pick the position to pool: the last non-padding token of each row.
        if self.config.pad_token_id is not None and input_ids is not None:
            # (index of first pad token) - 1. The modulo wraps -1 (no padding, or
            # left padding whose first pad sits at index 0) to the final position.
            # NOTE(review): assumes pad_token_id never appears inside real content.
            sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(dim=-1) - 1
            sequence_lengths = sequence_lengths % input_ids.shape[-1]
            # input_ids and hidden_states can be on different devices when the
            # backbone is sharded across GPUs. Index on the hidden-state device.
            sequence_lengths = sequence_lengths.to(hidden_states.device)
        else:
            # No pad id, or only inputs_embeds given: use the final position.
            # NOTE(review): with padded inputs_embeds this may pool a pad position.
            sequence_lengths = -1

        pooled_hidden = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths]
        logits = self.weights(pooled_hidden)

        loss = None
        if labels is not None:
            # Flatten so (batch_size, 1) labels, or a 0-d label with batch 1, match
            # the 1-D target that F.nll_loss expects. This is a no-op for (batch_size,).
            labels = labels.to(logits.device).reshape(-1)
            loss = self.loss_fn(logits, labels)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,  # shape [batch_size, 65]
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )