prokbert-mini-long / models.py
3v324v23's picture
Adding curriculum face model
acbb45f
# coding=utf-8
import warnings
import logging
from typing import Optional, Tuple, Union
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from transformers import MegatronBertConfig, MegatronBertModel, MegatronBertForMaskedLM, MegatronBertPreTrainedModel, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.utils.hub import cached_file
import math
#from prokbert.training_utils import compute_metrics_eval_prediction
def l2_norm(input, axis=1, epsilon=1e-12):
    """
    Scale `input` to unit L2 norm along `axis`.

    Args:
        input (torch.Tensor): Tensor to normalize.
        axis (int): Dimension along which the L2 norm is computed.
        epsilon (float): Lower bound applied to the norm before dividing, so
            that all-zero slices come back as zeros instead of NaN.

    Returns:
        torch.Tensor: Tensor with the same shape as `input`.
    """
    # `torch.norm` is deprecated; `torch.linalg.vector_norm` is the documented
    # replacement and computes the same 2-norm with the reduced dim kept.
    norm = torch.linalg.vector_norm(input, ord=2, dim=axis, keepdim=True)
    norm = torch.clamp(norm, min=epsilon)  # Avoid zero division
    return torch.div(input, norm)
def initialize_linear_kaiming(layer: nn.Linear):
    """
    Re-initialize an `nn.Linear` in place; any other module is left untouched.

    The weight is drawn with Kaiming-uniform initialization (gain for a
    'linear' nonlinearity) and the bias, when the layer has one, is zeroed.

    Args:
        layer (nn.Linear): Layer to initialize.
    """
    # Guard clause: silently ignore modules that are not linear layers
    # (callers may hand in e.g. an nn.Identity placeholder).
    if not isinstance(layer, nn.Linear):
        return

    nn.init.kaiming_uniform_(layer.weight, nonlinearity='linear')
    bias = layer.bias
    if bias is not None:
        nn.init.zeros_(bias)
class BertForBinaryClassificationWithPooling(nn.Module):
    """
    ProkBERT model for binary classification with custom pooling.

    This model extends a pre-trained `MegatronBertModel` by adding a weighting layer
    to compute a weighted sum over the sequence outputs, followed by a classifier.

    Attributes:
        base_model (MegatronBertModel): The base BERT model.
        weighting_layer (nn.Linear): Linear layer to compute weights for each token.
        dropout (nn.Dropout): Dropout layer.
        classifier (nn.Linear): Linear layer for classification.
    """

    def __init__(self, base_model: MegatronBertModel):
        """
        Initialize the BertForBinaryClassificationWithPooling model.

        Args:
            base_model (MegatronBertModel): A pre-trained `MegatronBertModel` instance.
        """
        super(BertForBinaryClassificationWithPooling, self).__init__()
        self.base_model = base_model
        self.base_model_config_dict = base_model.config.to_dict()
        self.hidden_size = self.base_model_config_dict['hidden_size']
        self.dropout_rate = self.base_model_config_dict['hidden_dropout_prob']
        self.weighting_layer = nn.Linear(self.hidden_size, 1)
        self.dropout = nn.Dropout(self.dropout_rate)
        self.classifier = nn.Linear(self.hidden_size, 2)

    def forward(self, input_ids, attention_mask=None, labels=None, output_hidden_states=False, output_pooled_output=False):
        """
        Run the encoder, pool the token representations and classify.

        Args:
            input_ids: Token ids forwarded to the base model.
            attention_mask: Optional mask forwarded to the base model.
            labels: Optional class indices; when given, a cross-entropy loss is added.
            output_hidden_states (bool): Also return the encoder hidden states.
            output_pooled_output (bool): Also return the pooled (post-dropout) vector.

        Returns:
            dict: Always contains "logits"; optionally "hidden_states",
            "pooled_output" and "loss".
        """
        # Force a ModelOutput so `outputs.hidden_states` below is valid even when
        # the base config was created with `return_dict=False`.
        outputs = self.base_model(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = outputs[0]

        # Compute weights for each position in the sequence.
        # NOTE(review): the softmax runs over every position, padding included;
        # `attention_mask` is only used inside the encoder. Left as is because
        # masking here would change the outputs of already-trained checkpoints.
        weights = self.weighting_layer(sequence_output)
        weights = torch.nn.functional.softmax(weights, dim=1)

        # Compute weighted sum
        pooled_output = torch.sum(weights * sequence_output, dim=1)

        # Classification head
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        # Prepare the output as a dictionary
        output = {"logits": logits}

        # Include hidden states in output if requested
        if output_hidden_states:
            output["hidden_states"] = outputs.hidden_states
        if output_pooled_output:
            output["pooled_output"] = pooled_output

        # If labels are provided, compute the loss
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()
            # Take the class count from the classifier instead of hard-coding 2.
            num_classes = self.classifier.out_features
            output["loss"] = loss_fct(logits.view(-1, num_classes), labels.view(-1))
        return output

    def save_pretrained(self, save_directory):
        """
        Save the model weights and configuration in a directory.

        Writes the full state dict to `pytorch_model.bin` and the base model
        configuration next to it.

        Args:
            save_directory (str): Directory where the model and configuration can be saved.
        """
        log = logging.getLogger(__name__)
        log.info('The save pretrained is called!')
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(save_directory, exist_ok=True)
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)
        log.info(f'The save directory is: {save_directory}')
        self.base_model.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load the model weights and configuration from a local directory or Hugging Face Hub.

        Args:
            pretrained_model_name_or_path (str): Directory path where the model and configuration were saved, or name of the model in Hugging Face Hub.
            **kwargs: An optional `config` entry is used as the model
                configuration; the remaining entries are forwarded to
                `MegatronBertConfig.from_pretrained` when no config is given.

        Returns:
            model: Instance of BertForBinaryClassificationWithPooling.
        """
        config = kwargs.pop('config', None)
        if config is None:
            config = MegatronBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        else:
            logging.getLogger(__name__).info('Config is in the parameters')

        model = cls(base_model=MegatronBertModel(config=config))

        # Only the location of the weight file differs between the local and the
        # Hub case; everything else is shared.
        if os.path.exists(pretrained_model_name_or_path):
            weights_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
        else:
            weights_path = cached_file(pretrained_model_name_or_path, "pytorch_model.bin")

        state_dict = torch.load(weights_path, map_location=torch.device('cpu'), weights_only=True)
        model.load_state_dict(state_dict)
        return model
class ProkBertConfig(MegatronBertConfig):
    """
    MegatronBERT configuration extended with ProkBERT-specific fields.

    Args:
        kmer (int): Stored as `self.kmer` (presumably the tokenizer k-mer
            length; confirm against the tokenizer).
        shift (int): Stored as `self.shift` (presumably the tokenization
            shift; confirm against the tokenizer).
        num_class_labels (int): Output size of the classification head.
        classification_dropout_rate (float): Dropout probability used in
            front of the classification head.
        **kwargs: Forwarded unchanged to `MegatronBertConfig`.
    """

    model_type = "prokbert"

    def __init__(
        self,
        kmer: int = 6,
        shift: int = 1,
        num_class_labels: int = 2,
        classification_dropout_rate: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Attach the ProkBERT-specific settings after the base config is built.
        prokbert_fields = (
            ("kmer", kmer),
            ("shift", shift),
            ("num_class_labels", num_class_labels),
            ("classification_dropout_rate", classification_dropout_rate),
        )
        for field_name, field_value in prokbert_fields:
            setattr(self, field_name, field_value)
class ProkBertConfigCurr(ProkBertConfig):
    """
    ProkBERT configuration for the CurricularFace classification model.

    Args:
        bert_base_model: Identifier of the base model; only stored on the config here.
        curricular_face_m: Margin `m` handed to the CurricularFace head.
        curricular_face_s: Scale `s` handed to the CurricularFace head.
        curricular_num_labels: Number of classes of the CurricularFace head.
        curriculum_hidden_size: Size of the optional projection in front of
            the head; -1 means no projection layer is used.
        classification_dropout_rate: Dropout probability applied to the pooled output.
        **kwargs: Forwarded unchanged to `ProkBertConfig`.
    """

    model_type = "prokbert"

    def __init__(
        self,
        bert_base_model="neuralbioinfo/prokbert-mini",
        curricular_face_m=0.5,
        curricular_face_s=64.,
        curricular_num_labels=2,
        curriculum_hidden_size=-1,
        classification_dropout_rate=0.0,
        **kwargs,
    ):
        # The dropout rate is deliberately not forwarded: the parent first sets
        # its own default and the value given here overrides it afterwards.
        super().__init__(**kwargs)
        curricular_fields = (
            ("curricular_num_labels", curricular_num_labels),
            ("curricular_face_m", curricular_face_m),
            ("curricular_face_s", curricular_face_s),
            ("bert_base_model", bert_base_model),
            ("curriculum_hidden_size", curriculum_hidden_size),
            ("classification_dropout_rate", classification_dropout_rate),
        )
        for field_name, field_value in curricular_fields:
            setattr(self, field_name, field_value)
class ProkBertClassificationConfig(ProkBertConfig):
    """
    ProkBERT configuration variant that sets `num_labels` directly.

    Args:
        num_labels (int): Assigned to `self.num_labels` after the base
            configuration has been built.
        classification_dropout_rate (float): Dropout probability for the
            classification head.
        **kwargs: Forwarded unchanged to `ProkBertConfig`.
    """

    model_type = "prokbert"

    def __init__(
        self,
        num_labels: int = 2,
        classification_dropout_rate: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Some extra steps will be added here later; for now this only
        # experiments with the plain config.
        for field_name, field_value in (
            ("num_labels", num_labels),
            ("classification_dropout_rate", classification_dropout_rate),
        ):
            setattr(self, field_name, field_value)
class ProkBertPreTrainedModel(PreTrainedModel):
    """
    Abstract base class for the ProkBERT models.

    It wires the ProkBERT configuration class into the `PreTrainedModel`
    machinery and provides the weight initialization used by the subclasses.
    """

    config_class = ProkBertConfig
    base_model_prefix = "bert"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of a single sub-module in place."""
        # Layer norms start as the identity transform.
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        init_std = self.config.initializer_range

        # Embeddings: normally distributed weights only.
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=init_std)
            return

        # Linear layers: normally distributed weights plus a zero bias.
        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
class ProkBertModel(MegatronBertModel):
    """MegatronBERT encoder that only accepts a `ProkBertConfig`."""

    config_class = ProkBertConfig

    def __init__(self, config: ProkBertConfig, **kwargs):
        """
        Build the encoder.

        Args:
            config (ProkBertConfig): Model configuration.
            **kwargs: Forwarded unchanged to `MegatronBertModel.__init__`.

        Raises:
            ValueError: If `config` is not a `ProkBertConfig` instance.
        """
        config_is_prokbert = isinstance(config, ProkBertConfig)
        if not config_is_prokbert:
            raise ValueError(f"Expected `ProkBertConfig`, got {config.__class__.__module__}.{config.__class__.__name__}")

        super().__init__(config, **kwargs)
        self.config = config
        # TODO: instead of rejecting a non-ProkBERT config, consider crafting a
        # proper ProkBERT config from it.
class ProkBertForMaskedLM(MegatronBertForMaskedLM):
    """MegatronBERT masked-LM model that only accepts a `ProkBertConfig`."""

    config_class = ProkBertConfig

    def __init__(self, config: ProkBertConfig, **kwargs):
        """
        Build the masked language model.

        Args:
            config (ProkBertConfig): Model configuration.
            **kwargs: Forwarded unchanged to `MegatronBertForMaskedLM.__init__`.

        Raises:
            ValueError: If `config` is not a `ProkBertConfig` instance.
        """
        config_is_prokbert = isinstance(config, ProkBertConfig)
        if not config_is_prokbert:
            raise ValueError(f"Expected `ProkBertConfig`, got {config.__class__.__module__}.{config.__class__.__name__}")

        super().__init__(config, **kwargs)
        self.config = config
        # TODO: instead of rejecting a non-ProkBERT config, consider crafting a
        # proper ProkBERT config from it.
class ProkBertForSequenceClassification(ProkBertPreTrainedModel):
    """
    ProkBERT encoder with learned weighted pooling and a linear classifier.

    Every token representation receives a scalar score from `weighting_layer`;
    the softmax of those scores over the sequence dimension weights the sum that
    forms the pooled vector, which is passed through dropout and `classifier`.
    """

    config_class = ProkBertConfig
    base_model_prefix = "bert"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.bert = ProkBertModel(config)
        self.weighting_layer = nn.Linear(self.config.hidden_size, 1)
        self.dropout = nn.Dropout(self.config.classification_dropout_rate)
        self.classifier = nn.Linear(self.config.hidden_size, self.config.num_class_labels)
        self.loss_fct = torch.nn.CrossEntropyLoss()
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Class indices in `[0, ..., config.num_class_labels - 1]`. When given, a
            cross-entropy loss is computed; this head has no regression branch.

        Returns a `SequenceClassifierOutput`, or, when `return_dict` resolves to
        `False`, a tuple of the non-`None` values among
        `(loss, logits, hidden_states, attentions)`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always request a ModelOutput from the encoder: the attribute access
        # below (`outputs.hidden_states`) is invalid on a plain tuple, which
        # previously made `return_dict=False` crash.
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = outputs[0]

        # Compute weights for each position in the sequence.
        # NOTE(review): padding positions take part in this softmax because
        # `attention_mask` is not applied here. Left unchanged since masking
        # would alter the outputs of existing checkpoints.
        weights = self.weighting_layer(sequence_output)
        weights = torch.nn.functional.softmax(weights, dim=1)

        # Compute weighted sum
        pooled_output = torch.sum(weights * sequence_output, dim=1)

        # Classification head
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = self.loss_fct(logits.view(-1, self.config.num_class_labels), labels.view(-1))

        if not return_dict:
            candidates = (loss, logits, outputs.hidden_states, outputs.attentions)
            return tuple(value for value in candidates if value is not None)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class CurricularFace(nn.Module):
    """
    CurricularFace margin head producing scaled cosine logits.

    Embeddings and the columns of `kernel` are L2-normalized, so their product
    is a cosine similarity. The target-class cosine gets an additive angular
    margin `m`; non-target cosines that exceed the margin-adjusted target
    ("hard examples") are re-weighted with the running statistic `t`. All
    logits are finally multiplied by the scale `s`.

    Args:
        in_features (int): Embedding dimension.
        out_features (int): Number of classes.
        m (float): Angular margin in radians.
        s (float): Scale applied to the cosine logits.
    """

    def __init__(self, in_features, out_features, m=0.5, s=64.):
        super(CurricularFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.m = m
        self.s = s
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.threshold = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m
        # `torch.Tensor(...)` only allocates memory; give the kernel defined
        # values so the head is usable even when no outer `_init_weights`
        # re-initializes it.
        self.kernel = Parameter(torch.empty(in_features, out_features))
        nn.init.normal_(self.kernel, std=0.01)
        # Exponential moving average of the target-class cosine.
        self.register_buffer('t', torch.zeros(1))

    def forward(self, embeddings, label):
        """
        Compute the margin-adjusted logits.

        Args:
            embeddings (torch.Tensor): `(batch_size, in_features)` embeddings.
            label (torch.Tensor): Integer class index per sample; any shape
                with `batch_size` elements is accepted.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: the margin-adjusted logits and
            the unmodified cosine logits (detached), both scaled by `s` and of
            shape `(batch_size, out_features)`.

        Raises:
            ValueError: If the number of labels differs from the batch size.
        """
        # Flatten first: a (batch, 1) label tensor would otherwise broadcast
        # against the row index below and gather a (batch, batch) block.
        label = label.reshape(-1).long()
        batch_size = embeddings.size(0)
        if label.numel() != batch_size:
            raise ValueError(
                f"Expected {batch_size} labels, got {label.numel()}"
            )

        # Normalize embeddings and the classifier kernel
        embeddings = F.normalize(embeddings, p=2, dim=1, eps=1e-12)
        kernel_norm = F.normalize(self.kernel, p=2, dim=0, eps=1e-12)

        # Cosine similarity between embeddings and kernel columns
        cos_theta = torch.mm(embeddings, kernel_norm).clamp(-1, 1)  # for numerical stability

        # Keep the unmodified cosines for the second return value
        with torch.no_grad():
            origin_cos = cos_theta.clone()

        # Cosine of each sample with its ground-truth class. Advanced indexing
        # yields a copy, so the in-place edits of `cos_theta` further down do
        # not invalidate what autograd saved for this step.
        rows = torch.arange(batch_size, device=cos_theta.device)
        target_logit = cos_theta[rows, label].view(-1, 1)

        sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))
        cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m  # cos(target + margin)

        # Positions whose cosine exceeds the margin-adjusted target cosine
        mask = cos_theta > cos_theta_m

        # Above the threshold use the cosine with margin, otherwise subtract a
        # fixed term.
        final_target_logit = torch.where(target_logit > self.threshold,
                                         cos_theta_m,
                                         target_logit - self.mm)

        # Update the running statistic only while training, so that evaluation
        # passes that supply labels do not change the module state.
        if self.training:
            with torch.no_grad():
                self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.t

        # Re-scale the hard examples
        hard_example = cos_theta[mask]
        cos_theta[mask] = hard_example * (self.t + hard_example)

        # Replace the logits of the target classes with the modified target logit
        final_target_logit = final_target_logit.to(cos_theta.dtype)
        cos_theta.scatter_(1, label.view(-1, 1), final_target_logit)

        output = cos_theta * self.s
        return output, origin_cos * self.s
class ProkBertForCurricularClassification(ProkBertPreTrainedModel):
    """
    ProkBERT encoder with masked weighted pooling and a CurricularFace head.

    With labels the model returns CurricularFace logits and a cross-entropy
    loss. Without labels it returns the L2-normalized pooled embedding, because
    the CurricularFace head needs the labels to build its logits.
    """

    config_class = ProkBertConfigCurr
    base_model_prefix = "bert"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.bert = ProkBertModel(config)

        # A weighting layer for pooling the sequence output
        self.weighting_layer = nn.Linear(self.config.hidden_size, 1)
        self.dropout = nn.Dropout(self.config.classification_dropout_rate)

        # Optional projection in front of the head; -1 disables it.
        if config.curriculum_hidden_size != -1:
            self.linear = nn.Linear(self.config.hidden_size, config.curriculum_hidden_size)
            embedding_dim = config.curriculum_hidden_size
        else:
            self.linear = nn.Identity()
            embedding_dim = self.config.hidden_size

        # CurricularFace head replacing a plain linear classifier; margin and
        # scale come from the config.
        self.curricular_face = CurricularFace(embedding_dim,
                                              self.config.curricular_num_labels,
                                              m=self.config.curricular_face_m,
                                              s=self.config.curricular_face_s)
        self.loss_fct = torch.nn.CrossEntropyLoss()
        self.post_init()

    def _init_weights(self, module: nn.Module):
        """Initialize `module`, with dedicated schemes for the custom head layers."""
        # first let the base class init everything else
        super()._init_weights(module)

        # then override the pooling scorer: Xavier-uniform weight, zero bias
        if module is getattr(self, "weighting_layer", None):
            nn.init.xavier_uniform_(module.weight)
            nn.init.zeros_(module.bias)
        # projection layer (the helper ignores the nn.Identity placeholder)
        if module is getattr(self, "linear", None):
            initialize_linear_kaiming(module)
        # CurricularFace kernel
        if module is getattr(self, "curricular_face", None):
            nn.init.kaiming_uniform_(module.kernel, a=math.sqrt(self.config.curricular_num_labels))

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Class indices for the CurricularFace head.

        Returns:
            - without `labels`: a tensor holding the L2-normalized pooled
              embedding of each sequence;
            - with `labels`: a `SequenceClassifierOutput`, or, when
              `return_dict` resolves to `False`, a tuple of the non-`None`
              values among `(loss, logits, hidden_states, attentions)`.

        Raises:
            ValueError: If `attention_mask` is neither 2-D nor 4-D.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Get the outputs from the base ProkBert model. A ModelOutput is always
        # requested because the attribute access at the end is invalid on a tuple.
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = outputs[0]  # (batch_size, seq_length, hidden_size)

        # Pool the sequence output using a learned weighting (attention-like)
        weights = self.weighting_layer(sequence_output)  # (batch_size, seq_length, 1)

        # `attention_mask` is optional: without it every position takes part in
        # the pooling.
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                mask = attention_mask
            elif attention_mask.dim() == 4:
                # assumes a (batch_size, 1, 1, seq_length) mask that still uses
                # 0 for masked positions. TODO confirm against callers
                mask = attention_mask.squeeze(1).squeeze(1)  # (batch_size, seq_length)
            else:
                raise ValueError(f"Unexpected attention_mask shape {attention_mask.shape}")
            # Masked positions -> -inf so they receive zero weight after softmax
            weights = weights.masked_fill(mask.unsqueeze(-1) == 0, float('-inf'))

        # Normalize over the sequence dimension
        weights = torch.nn.functional.softmax(weights, dim=1)  # (batch_size, seq_length, 1)

        # Weighted pooling
        pooled_output = torch.sum(weights * sequence_output, dim=1)  # (batch_size, hidden_size)

        # Classifier head
        pooled_output = self.dropout(pooled_output)
        pooled_output = self.linear(pooled_output)

        # CurricularFace requires the embeddings and the corresponding labels.
        # During inference (labels is None) only the normalized embedding is returned.
        if labels is None:
            return l2_norm(pooled_output, axis=1)

        logits, origin_cos = self.curricular_face(pooled_output, labels)
        loss = self.loss_fct(logits, labels.view(-1))

        if not return_dict:
            candidates = (loss, logits, outputs.hidden_states, outputs.attentions)
            return tuple(value for value in candidates if value is not None)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )