# ESPFormer/modeling_ESPFormer.py
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from .configuration_ESPFormer import ESPFormerConfig
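
# Note: the relative import above means this file must live in a package
# alongside configuration_ESPFormer.py; it cannot be run as a loose script.
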
class ESPFormerModel(PreTrainedModel):
config_class = ESPFormerConfig

    def __init__(self, config):
super().__init__(config)
self.embed_dim = config.embed_dim
self.seq_length = config.seq_length
self.embeddings = nn.Linear(config.input_dim, config.embed_dim)
nn.init.xavier_uniform_(self.embeddings.weight)
if self.embeddings.bias is not None:
nn.init.zeros_(self.embeddings.bias)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim))  # nn.Parameter sets requires_grad=True automatically
nn.init.normal_(self.cls_token, std=0.02)
        # Learned positional embeddings; seq_length + 1 accounts for the CLS token.
        # torch.randn(...) would work here too, making the nn.init.normal_ call unnecessary.
        self.positional_embeddings = nn.Parameter(
            torch.zeros(1, config.seq_length + 1, config.embed_dim)
        )
        nn.init.normal_(self.positional_embeddings, std=0.02)
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.embed_dim,
            nhead=config.num_heads,
            dim_feedforward=config.embed_dim * 4,
            dropout=config.dropout,
            activation='gelu',
            batch_first=True,  # inputs are (batch, seq, embed_dim)
        )  # parameters are initialized automatically
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config.num_layers,
        )
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(config.embed_dim, config.embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.embed_dim // 2, config.num_classes),
        )
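
    # Data flow in forward():
    #   input_ids (B, S, input_dim) --embeddings--> (B, S, E)
    #   prepend CLS token --> (B, S+1, E), then add positional embeddings
    #   transformer encoder --> CLS hidden state --> classifier --> logits (B, num_classes)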
    def forward(self, input_ids=None, attention_mask=None, labels=None):
        """
        input_ids: Tensor of shape (batch_size, seq_length, input_dim)
        labels: optional Tensor of shape (batch_size,)
        attention_mask is accepted for pipeline compatibility but is not used.
        """
        batch_size = input_ids.size(0)
        # Embeddings
        x = self.embeddings(input_ids)
        # Prepend the CLS token to every sequence in the batch
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        # Add positional embeddings
        x = x + self.positional_embeddings
        x = self.transformer_encoder(x)
        cls_output = x[:, 0, :]  # final hidden state of the CLS token
        logits = self.classifier(cls_output)
loss = None
if labels is not None:
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, labels)
        # Return (loss, logits) when labels are provided, otherwise (logits,)
        output = (logits,)
        return ((loss,) + output) if loss is not None else output
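

# A minimal smoke-test sketch. It assumes ESPFormerConfig (defined in
# configuration_ESPFormer.py, not shown here) accepts the fields referenced
# above -- input_dim, embed_dim, seq_length, num_heads, num_layers, dropout,
# num_classes -- as keyword arguments, the usual PretrainedConfig pattern.
# Because of the relative import, run it as a module, e.g.
# `python -m ESPFormer.modeling_ESPFormer`, not as a standalone script.
if __name__ == "__main__":
    config = ESPFormerConfig(
        input_dim=16,    # assumed per-timestep feature dimension
        embed_dim=64,
        seq_length=32,
        num_heads=4,     # must divide embed_dim
        num_layers=2,
        dropout=0.1,
        num_classes=5,
    )
    model = ESPFormerModel(config)
    model.eval()

    batch = torch.randn(8, config.seq_length, config.input_dim)  # (batch, seq, input_dim)
    labels = torch.randint(0, config.num_classes, (8,))

    loss, logits = model(input_ids=batch, labels=labels)
    print(logits.shape)  # torch.Size([8, 5])
    print(loss.item())   # scalar cross-entropy loss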