opinder2906 committed · verified
Commit 2bace3f · Parent(s): b9797cb

Update src/model_def.py

Files changed (1):
  1. src/model_def.py +11 -11
src/model_def.py CHANGED

@@ -2,9 +2,10 @@ import torch
 import torch.nn as nn
 import numpy as np
 
-# Positional encoding for Transformer
+# Positional Encoding class
+def class (rename to PositionalEncoding)
 class PositionalEncoding(nn.Module):
-    def __init__(self, d_model, max_len=5000):
+    def __init__(self, d_model, max_len=32):
         super().__init__()
         pe = torch.zeros(max_len, d_model)
         position = torch.arange(0, max_len).unsqueeze(1)
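The first hunk shrinks max_len from 5000 to 32; the lines that compute the sinusoidal terms (old file lines 11-13) fall between the two hunks and are not shown. Below is a minimal runnable sketch of the post-commit PositionalEncoding, with the elided lines filled in as the standard sin/cos terms (an assumption, not the committed code) and the stray "def class (rename to PositionalEncoding)" line dropped, since it is not valid Python:

import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=32):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        # Assumed standard sinusoidal terms; the diff elides these lines.
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)  # (1, max_len, d_model), broadcasts over the batch

    def forward(self, x):
        # Add the encoding for the first seq_len positions, on x's device.
        return x + self.pe[:, :x.size(1)].to(x.device)

Note that with max_len=32, any input longer than 32 tokens makes the addition in forward fail with a shape mismatch, so the commit implicitly assumes sequences are truncated or padded to at most 32 tokens upstream.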
@@ -14,23 +15,22 @@ class PositionalEncoding(nn.Module):
         self.pe = pe.unsqueeze(0)
 
     def forward(self, x):
-        x = x + self.pe[:, : x.size(1)]
-        return x
+        return x + self.pe[:, :x.size(1)].to(x.device)
 
-# Transformer-based classifier
-authors@article not relevant
+# Transformer emotion classifier
+def class (rename to EmotionTransformer)
 class EmotionTransformer(nn.Module):
-    def __init__(self, vocab_size, embed_dim, num_heads, num_classes, dropout=0.1):
+    def __init__(self, vocab_size, embed_dim=64, num_heads=4, num_classes=None):
         super().__init__()
-        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
         self.pos_encoder = PositionalEncoding(embed_dim)
-        encoder_layer = nn.TransformerEncoderLayer(embed_dim, num_heads)
+        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
-        self.dropout = nn.Dropout(dropout)
+        self.dropout = nn.Dropout(0.3)
         self.fc = nn.Linear(embed_dim, num_classes)
 
     def forward(self, x):
-        mask = (x == 0) # pad index = 0
+        mask = (x == 0)
         x = self.embedding(x)
         x = self.pos_encoder(x)
         x = self.transformer(x, src_key_padding_mask=mask)
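The second hunk reserves token id 0 for padding (padding_idx=0 plus the x == 0 key-padding mask), switches the encoder to batch_first=True, moves the positional buffer to x's device in forward, and hard-codes dropout at 0.3. The diff cuts off after the transformer call, so the classifier head is not shown. Continuing the sketch above with an assumed head (mean-pool, dropout, then the linear layer; again an assumption, and the stray "def class (rename to EmotionTransformer)" line is dropped):

class EmotionTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim=64, num_heads=4, num_classes=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.pos_encoder = PositionalEncoding(embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        mask = (x == 0)          # True at pad positions (pad index is 0)
        x = self.embedding(x)    # (batch, seq) -> (batch, seq, embed_dim)
        x = self.pos_encoder(x)
        x = self.transformer(x, src_key_padding_mask=mask)
        # Assumed head: the diff ends at the line above. A plain mean also
        # averages pad positions; a mask-aware mean would be stricter.
        x = self.dropout(x.mean(dim=1))
        return self.fc(x)

A hypothetical smoke test (vocab size, class count, and token ids are made up, not from the repo):

model = EmotionTransformer(vocab_size=1000, num_classes=6)
batch = torch.tensor([[5, 42, 7, 0, 0],      # trailing zeros are padding
                      [9, 13, 27, 31, 2]])
logits = model(batch)  # batch_first=True, so the input is (batch, seq)
print(logits.shape)    # torch.Size([2, 6])

Because padding_idx and the mask are both keyed to token id 0, the tokenizer must keep id 0 reserved for padding; and since num_classes=None would crash nn.Linear, callers always have to pass a class count explicitly.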
 