import torch
import torch.nn as nn
from transformers import PreTrainedModel
from .configuration_ESPFormer import ESPFormerConfig

class ESPFormerModel(PreTrainedModel):
    config_class = ESPFormerConfig

    def __init__(self, config):
        super().__init__(config)

        self.embed_dim = config.embed_dim
        self.seq_length = config.seq_length

        # Input projection into the embedding dimension
        self.embeddings = nn.Linear(config.input_dim, config.embed_dim)
        nn.init.xavier_uniform_(self.embeddings.weight)
        if self.embeddings.bias is not None:
            nn.init.zeros_(self.embeddings.bias)

        # Learnable CLS token; nn.Parameter sets requires_grad=True automatically
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim))
        nn.init.normal_(self.cls_token, std=0.02)

        # Learnable positional embeddings for seq_length positions plus the CLS token.
        # Alternatively, initialize with torch.randn(...) and skip nn.init.normal_.
        self.positional_embeddings = nn.Parameter(
            torch.zeros(1, config.seq_length + 1, config.embed_dim)
        )
        nn.init.normal_(self.positional_embeddings, std=0.02)

        # Transformer encoder: a stack of identical encoder layers.
        # nn.TransformerEncoderLayer initializes its parameters automatically.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.embed_dim,
            nhead=config.num_heads,
            dim_feedforward=config.embed_dim * 4,
            dropout=config.dropout,
            activation='gelu',
            batch_first=True,
        )

        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config.num_layers,
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(config.embed_dim, config.embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.embed_dim // 2, config.num_classes),
        )

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        """
        input_ids: Tensor of shape (batch_size, seq_length, input_dim).
        attention_mask: accepted for interface compatibility; currently unused.
        labels: optional Tensor of shape (batch_size,); if given, a loss is returned.

        Returns (loss, logits) if labels are given, otherwise (logits,).
        See the usage sketch at the bottom of this file.
        """
        batch_size = input_ids.size(0)

        # Project inputs to the embedding dimension
        x = self.embeddings(input_ids)

        # Prepend the CLS token to every sequence
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Add positional embeddings
        x = x + self.positional_embeddings

        # Encode the full sequence
        x = self.transformer_encoder(x)

        # Classify from the CLS token representation
        cls_output = x[:, 0, :]
        logits = self.classifier(cls_output)
        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)

        output = (logits,)
        return ((loss,) + output) if loss is not None else output
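

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The configuration values below are
# assumptions chosen for demonstration; ESPFormerConfig is expected to expose
# the attributes read in __init__ above (input_dim, embed_dim, seq_length,
# num_heads, num_layers, dropout, num_classes). Adjust them to match the
# actual configuration_ESPFormer definition before running.
#
# config = ESPFormerConfig(
#     input_dim=16,
#     embed_dim=64,
#     seq_length=128,
#     num_heads=4,
#     num_layers=2,
#     dropout=0.1,
#     num_classes=10,
# )
# model = ESPFormerModel(config)
#
# inputs = torch.randn(8, config.seq_length, config.input_dim)  # (batch, seq, features)
# labels = torch.randint(0, config.num_classes, (8,))
#
# loss, logits = model(input_ids=inputs, labels=labels)  # training-style call
# (logits,) = model(input_ids=inputs)                    # inference-style call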