import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_ESPFormer import ESPFormerConfig


class ESPFormerModel(PreTrainedModel):
    config_class = ESPFormerConfig

    def __init__(self, config):
        super().__init__(config)

        self.embed_dim = config.embed_dim
        self.seq_length = config.seq_length

        # Project raw input features into the transformer embedding space.
        self.embeddings = nn.Linear(config.input_dim, config.embed_dim)
        nn.init.xavier_uniform_(self.embeddings.weight)
        if self.embeddings.bias is not None:
            nn.init.zeros_(self.embeddings.bias)

        # Learnable [CLS] token; its final hidden state feeds the classifier.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim))
        nn.init.normal_(self.cls_token, std=0.02)

        # Learnable positional embeddings (seq_length + 1 slots: [CLS] + sequence).
        self.positional_embeddings = nn.Parameter(
            torch.zeros(1, config.seq_length + 1, config.embed_dim)
        )
        nn.init.normal_(self.positional_embeddings, std=0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.embed_dim,
            nhead=config.num_heads,
            dim_feedforward=config.embed_dim * 4,
            dropout=config.dropout,
            activation="gelu",
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config.num_layers,
        )

        # Two-layer classification head applied to the [CLS] representation.
        self.classifier = nn.Sequential(
            nn.Linear(config.embed_dim, config.embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.embed_dim // 2, config.num_classes),
        )

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        """
        input_ids: Tensor of shape (batch_size, seq_length, input_dim)
        labels: Tensor of shape (batch_size,)

        attention_mask is accepted for interface compatibility but is not
        currently used.
        """
        batch_size = input_ids.size(0)

        # Project inputs, prepend the [CLS] token, and add positional embeddings.
        x = self.embeddings(input_ids)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.positional_embeddings

        x = self.transformer_encoder(x)

        # Classify from the final hidden state of the [CLS] token.
        cls_output = x[:, 0, :]
        logits = self.classifier(cls_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)

        # Tuple output, following the pre-ModelOutput transformers convention:
        # (loss, logits) when labels are given, otherwise (logits,).
        output = (logits,)
        return ((loss,) + output) if loss is not None else output
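

# Minimal usage sketch, not part of the model definition. It assumes
# ESPFormerConfig accepts the fields used above (input_dim, embed_dim,
# seq_length, num_heads, num_layers, dropout, num_classes) as keyword
# arguments, and that this file is executed as part of its package
# (e.g. via python -m) so the relative import resolves; the dimensions
# below are arbitrary.
if __name__ == "__main__":
    config = ESPFormerConfig(
        input_dim=16,
        embed_dim=64,
        seq_length=32,
        num_heads=4,
        num_layers=2,
        dropout=0.1,
        num_classes=3,
    )
    model = ESPFormerModel(config)

    # A random batch of 8 sequences of continuous features.
    inputs = torch.randn(8, config.seq_length, config.input_dim)
    labels = torch.randint(0, config.num_classes, (8,))

    loss, logits = model(input_ids=inputs, labels=labels)
    print(loss.item(), logits.shape)  # scalar loss, torch.Size([8, 3])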