AlexSychovUN committed on
Commit
077b6c3
·
1 Parent(s): b3705e9

Added files

Browse files
Files changed (1) hide show
  1. transformer_from_scratch/model.py +108 -0
transformer_from_scratch/model.py CHANGED
@@ -171,3 +171,111 @@ class Encoder(nn.Module):
171
  x = layer(x, mask)
172
 
173
  return self.norm(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  x = layer(x, mask)
172
 
173
  return self.norm(x)
174
+
175
+
176
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention over the
    encoder output, and a position-wise feed-forward network, each wrapped
    in its own residual connection."""

    def __init__(self, self_attention_block: MultiHeadAttention, cross_attention_block: MultiHeadAttention, feed_forward_block: FeedForwardBlock, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        # One residual wrapper per sub-layer: self-attn, cross-attn, FFN.
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """x: decoder input; encoder_output: encoder memory.
        src_mask hides encoder padding; tgt_mask hides future target positions."""
        residuals = self.residual_connections
        # Masked self-attention over the target sequence.
        x = residuals[0](x, lambda q: self.self_attention_block(q, q, q, tgt_mask))
        # Cross-attention: queries from the decoder, keys/values from the encoder.
        x = residuals[1](x, lambda q: self.cross_attention_block(q, encoder_output, encoder_output, src_mask))
        # Position-wise feed-forward.
        return residuals[2](x, self.feed_forward_block)
190
+
191
class Decoder(nn.Module):
    """Stack of decoder layers followed by a final layer normalization."""

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        # Thread the target representation through every decoder layer in order.
        for decoder_layer in self.layers:
            x = decoder_layer(x, encoder_output, src_mask, tgt_mask)
        # Normalize once after the full stack (pre-LN style).
        return self.norm(x)
201
+
202
+
203
class ProjectionLayer(nn.Module):
    """Maps decoder features to per-token log-probabilities over the vocabulary."""

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        # Attribute name `proj` kept so state_dict checkpoint keys stay stable.
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # (batch, seq_len, d_model) -> (batch, seq_len, vocab_size)
        logits = self.proj(x)
        return torch.log_softmax(logits, dim=-1)
211
+
212
+
213
+
214
class Transformer(nn.Module):
    """Full encoder-decoder transformer: embeddings + positional encodings on
    both sides, an encoder stack, a decoder stack, and a vocabulary projection."""

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        """Embed and positionally encode the source, then run the encoder."""
        embedded = self.src_pos(self.src_embed(src))
        return self.encoder(embedded, src_mask)

    def decode(self, encoder_output, src_mask, tgt, tgt_mask):
        """Embed and positionally encode the target, then run the decoder
        against the encoder memory."""
        embedded = self.tgt_pos(self.tgt_embed(tgt))
        return self.decoder(embedded, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        """Map decoder features to log-probabilities over the target vocabulary."""
        return self.projection_layer(x)
237
+
238
+
239
def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int = 512, N: int = 6, h: int = 8, dropout: float = 0.1, d_ff: int = 2048) -> Transformer:
    """Assemble a complete encoder-decoder transformer.

    Args:
        src_vocab_size: source vocabulary size.
        tgt_vocab_size: target vocabulary size.
        src_seq_len: maximum source sequence length (sizes the positional encoding).
        tgt_seq_len: maximum target sequence length.
        d_model: embedding / model dimension.
        N: number of encoder layers and number of decoder layers.
        h: number of attention heads.
        dropout: dropout probability (annotation fixed from `int` — the default
            0.1 and every consumer expect a float).
        d_ff: hidden size of the feed-forward sub-layers.

    Returns:
        A Transformer whose weight matrices are Xavier-uniform initialized.
    """
    # Embedding layers
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

    # Positional encodings, sized to each side's maximum sequence length
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    # N encoder blocks, each with its own attention and feed-forward modules
    encoder_blocks = [
        EncoderBlock(
            MultiHeadAttention(d_model, h, dropout),
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]

    # N decoder blocks: self-attention, cross-attention over the encoder, feed-forward
    decoder_blocks = [
        DecoderBlock(
            MultiHeadAttention(d_model, h, dropout),
            MultiHeadAttention(d_model, h, dropout),
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]

    # Encoder and decoder stacks
    encoder = Encoder(nn.ModuleList(encoder_blocks))
    decoder = Decoder(nn.ModuleList(decoder_blocks))

    # Projection onto the target vocabulary
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

    # Wire everything together
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

    # Xavier-initialize every weight matrix; 1-D params (biases, norms) keep their defaults
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer
281
+