import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_ESPFormer import ESPFormerConfig


class ESPFormerModel(PreTrainedModel):
    config_class = ESPFormerConfig

    def __init__(self, config):
        super().__init__(config)
        self.embed_dim = config.embed_dim
        self.seq_length = config.seq_length

        # Linear projection from raw input features to the model dimension.
        self.embeddings = nn.Linear(config.input_dim, config.embed_dim)
        nn.init.xavier_uniform_(self.embeddings.weight)
        if self.embeddings.bias is not None:
            nn.init.zeros_(self.embeddings.bias)

        # Learnable [CLS] token; nn.Parameter sets requires_grad=True automatically.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim))
        nn.init.normal_(self.cls_token, std=0.02)

        # Learnable positional embeddings (+1 position for the [CLS] token).
        # Alternatively, initialize directly with torch.randn(...) scaled by 0.02
        # and skip the separate nn.init.normal_ call.
        self.positional_embeddings = nn.Parameter(
            torch.zeros(1, config.seq_length + 1, config.embed_dim)
        )
        nn.init.normal_(self.positional_embeddings, std=0.02)

        # Transformer encoder; the layer initializes its own parameters.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.embed_dim,
            nhead=config.num_heads,
            dim_feedforward=config.embed_dim * 4,
            dropout=config.dropout,
            activation="gelu",
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config.num_layers,
        )

        # Classification head applied to the [CLS] representation.
        self.classifier = nn.Sequential(
            nn.Linear(config.embed_dim, config.embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.embed_dim // 2, config.num_classes),
        )

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        """
        input_ids: Tensor of shape (batch_size, seq_length, input_dim)
        attention_mask: optional Tensor of shape (batch_size, seq_length),
            with 1 for valid positions and 0 for padding
        labels: optional Tensor of shape (batch_size,)
        """
        batch_size = input_ids.size(0)

        # Project inputs into the embedding space.
        x = self.embeddings(input_ids)

        # Prepend the [CLS] token to every sequence.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Add positional embeddings.
        x = x + self.positional_embeddings

        # Build a key-padding mask if an attention mask is given; the [CLS]
        # position is never masked. TransformerEncoder expects True = ignore.
        src_key_padding_mask = None
        if attention_mask is not None:
            cls_mask = attention_mask.new_ones(batch_size, 1)
            src_key_padding_mask = torch.cat((cls_mask, attention_mask), dim=1) == 0

        x = self.transformer_encoder(x, src_key_padding_mask=src_key_padding_mask)

        # Classify from the [CLS] token's final hidden state.
        cls_output = x[:, 0, :]
        logits = self.classifier(cls_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)

        output = (logits,)
        return ((loss,) + output) if loss is not None else output
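

# --- Usage sketch (illustrative smoke test, not part of the module) ---
# A minimal forward-pass check, assuming ESPFormerConfig accepts the keyword
# arguments used below (input_dim, embed_dim, seq_length, num_heads,
# num_layers, dropout, num_classes); the actual signature lives in
# configuration_ESPFormer.py. The dimensions here are arbitrary. Run with
# `python -m <package>.modeling_ESPFormer` so the relative import resolves.
if __name__ == "__main__":
    config = ESPFormerConfig(
        input_dim=16,
        embed_dim=64,
        seq_length=32,
        num_heads=4,
        num_layers=2,
        dropout=0.1,
        num_classes=3,
    )
    model = ESPFormerModel(config)

    # Dummy batch: continuous features, not token ids, per the docstring.
    x = torch.randn(8, config.seq_length, config.input_dim)
    labels = torch.randint(0, config.num_classes, (8,))

    loss, logits = model(input_ids=x, labels=labels)
    print(loss.item(), logits.shape)  # scalar loss, (8, num_classes)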