AlexSychovUN commited on
Commit
1e0ff3f
·
1 Parent(s): 956b371

Added files

Browse files
Files changed (1) hide show
  1. transformer_from_scratch/model.py +109 -1
transformer_from_scratch/model.py CHANGED
@@ -62,4 +62,112 @@ class LayerNormalization(nn.Module):
62
  class FeedForwardBlock(nn.Module):
63
  def __init__(self, d_model: int, d_ff: int, dropout: float):
64
  super().__init__()
65
- self.linear1 = nn.Linear(d_model, d_ff)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  class FeedForwardBlock(nn.Module):
63
  def __init__(self, d_model: int, d_ff: int, dropout: float):
64
  super().__init__()
65
+ self.linear1 = nn.Linear(d_model, d_ff) # W1 and b1, bias = True
66
+ self.dropout = nn.Dropout(dropout)
67
+ self.linear2 = nn.Linear(d_ff, d_model) # W2 and b2, bias = True
68
+
69
+ def forward(self, x):
70
+ # (Batch, Seq_len, d_model) --> (Batch, Seq_len, d_ff) --> (Batch, Seq-len, d_model)
71
+ return self.linear2(self.dropout(torch.relu(self.linear1(x))))
72
+
73
+
74
+ class MultiHeadAttention(nn.Module):
75
+ def __init__(self, d_model: int, h: int, dropout: float):
76
+ super().__init__()
77
+ self.d_model = d_model
78
+ self.h = h
79
+
80
+ assert d_model % h == 0, "d_model must be divisible by h"
81
+
82
+ self.d_k = d_model // h
83
+ self.w_q = nn.Linear(d_model, d_model) # Wq
84
+ self.w_k = nn.Linear(d_model, d_model) # Wk
85
+ self.w_v = nn.Linear(d_model, d_model) # Wv
86
+
87
+ self.wo = nn.Linear(d_model, d_model) # Wo
88
+ self.dropout = nn.Dropout(dropout)
89
+
90
+ @staticmethod
91
+ def attention(query, key, value, mask, dropout: nn.Dropout):
92
+ d_k = query.size(-1)
93
+
94
+ # (Batch, h, seq_len, d_k) --> (Batch, h, seq_len, seq_len)
95
+ attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
96
+ if mask is not None:
97
+ attention_scores.masked_fill_(mask == 0, -1e9)
98
+ attention_scores = attention_scores.softmax(
99
+ dim=-1
100
+ ) # (Batch, h, seq_Len, seq_len)
101
+ if dropout is not None:
102
+ attention_scores = dropout(attention_scores)
103
+
104
+ return (attention_scores @ value), attention_scores
105
+
106
+ def forward(self, q, k, v, mask):
107
+ query = self.w_q(q) # (Batch, Seq_Len, d_model) --> (Batch, Seq_Len, d_model)
108
+ key = self.w_k(k) # (Batch, Seq_Len, d_model) --> (Batch, Seq_Len, d_model)
109
+ value = self.w_v(v) # (Batch, Seq_Len, d_model) --> (Batch, Seq_Len, d_model)
110
+
111
+ # (Batch, Seq_Len, d_model) --> (Batch, Seq_len, h, d_k) --> (Batch, h, Seq_len, d_k)
112
+ query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(
113
+ 1, 2
114
+ )
115
+ key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
116
+ value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(
117
+ 1, 2
118
+ )
119
+
120
+ x, attention_scores = MultiHeadAttention.attention(
121
+ query, key, value, mask, self.dropout
122
+ )
123
+
124
+ # (Batch, h, seq_len, d_k) --> (Batch, Seq_len, h, d_k) --> (Batch, Seq_len, d_model), contiguous - in place
125
+ x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
126
+
127
+ # (Batch, Seq_len, d_model) --> (Batch, Seq_len, d_model)
128
+ return self.wo(x)
129
+
130
+
131
+ class ResidualConnection(nn.Module):
132
+ def __init__(self, dropout: float):
133
+ super().__init__()
134
+ self.dropout = nn.Dropout(dropout)
135
+ self.norm = LayerNormalization()
136
+
137
+ def forward(self, x, sublayer): # sublayer - the previous layer
138
+ return x + self.dropout(sublayer(self.norm(x)))
139
+
140
+
141
+ class EncoderBlock(nn.Module):
142
+ def __init__(
143
+ self,
144
+ self_attention_block: MultiHeadAttention,
145
+ feed_forward_block: FeedForwardBlock,
146
+ dropout: float,
147
+ ):
148
+ super().__init__()
149
+ self.self_attention_block = self_attention_block
150
+ self.feed_forward_block = feed_forward_block
151
+ self.residual_connections = nn.ModuleList(
152
+ [ResidualConnection(dropout) for _ in range(2)]
153
+ )
154
+
155
+ def forward(self, x, src_mask):
156
+ x = self.residual_connections[0](
157
+ x, lambda x: self.self_attention_block(x, x, x, src_mask)
158
+ )
159
+ x = self.residual_connections[1](x, self.feed_forward_block)
160
+ return x
161
+
162
+
163
+ class Encoder(nn.Module):
164
+ def __init__(self, layers: nn.ModuleList):
165
+ super().__init__()
166
+ self.layers = layers
167
+ self.norm = LayerNormalization()
168
+
169
+ def forward(self, x, mask):
170
+ for layer in self.layers:
171
+ x = layer(x, mask)
172
+
173
+ return self.norm(x)