Commit ·
077b6c3
1
Parent(s): b3705e9
Added files
Browse files
transformer_from_scratch/model.py
CHANGED
|
@@ -171,3 +171,111 @@ class Encoder(nn.Module):
|
|
| 171 |
x = layer(x, mask)
|
| 172 |
|
| 173 |
return self.norm(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
x = layer(x, mask)
|
| 172 |
|
| 173 |
return self.norm(x)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class DecoderBlock(nn.Module):
    """Single decoder layer built from three residual-wrapped sub-layers.

    The sub-layers run in this order: self-attention over the decoder input,
    cross-attention against the encoder output, then the feed-forward block.
    Each one is applied through its own ``ResidualConnection``.
    """

    def __init__(self, self_attention_block: MultiHeadAttention, cross_attention_block: MultiHeadAttention, feed_forward_block: FeedForwardBlock, dropout: float):
        super().__init__()
        # Registration order is kept stable so parameter ordering does not change.
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        # One residual wrapper per sub-layer (self-attn, cross-attn, feed-forward).
        residual_wrappers = [ResidualConnection(dropout) for _ in range(3)]
        self.residual_connections = nn.ModuleList(residual_wrappers)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Run the decoder layer.

        Args:
            x: input of the decoder.
            encoder_output: encoder result, used as the second and third
                arguments of the cross-attention call.
            src_mask: mask for the encoder side, handed to cross-attention.
            tgt_mask: mask applied to the decoder, handed to self-attention.

        Returns:
            The output of the third (feed-forward) residual connection.
        """
        self_attn_residual, cross_attn_residual, ff_residual = self.residual_connections

        def attend_to_self(hidden):
            # The same tensor is passed for all three attention inputs.
            return self.self_attention_block(hidden, hidden, hidden, tgt_mask)

        def attend_to_encoder(hidden):
            # Decoder state is the first input; the encoder output fills the other two.
            return self.cross_attention_block(hidden, encoder_output, encoder_output, src_mask)

        x = self_attn_residual(x, attend_to_self)
        x = cross_attn_residual(x, attend_to_encoder)
        return ff_residual(x, self.feed_forward_block)
|
| 190 |
+
|
| 191 |
+
class Decoder(nn.Module):
    """Sequence of decoder layers followed by a final ``LayerNormalization``."""

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Feed ``x`` through every layer in order, then normalise the result.

        Every layer receives the same ``encoder_output``, ``src_mask`` and
        ``tgt_mask``; only the running hidden state changes between layers.
        """
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, encoder_output, src_mask, tgt_mask)
        normalized = self.norm(hidden)
        return normalized
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class ProjectionLayer(nn.Module):
    """Linear map from model features to vocabulary log-probabilities."""

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Project ``x`` and apply log-softmax over the last dimension.

        Shape: (Batch, Seq_len, d_model) --> (Batch, Seq_len, Vocab_size)
        """
        logits = self.proj(x)
        log_probs = torch.log_softmax(logits, dim=-1)
        return log_probs
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class Transformer(nn.Module):
    """Bundle of the components of an encoder-decoder model.

    The class exposes three separate steps, ``encode``, ``decode`` and
    ``project``, instead of a single ``forward``.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer):
        super().__init__()
        # Sub-modules are registered in the original order so that
        # ``parameters()`` and ``state_dict()`` ordering stay unchanged.
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        """Embed ``src``, apply the source positional encoding, run the encoder."""
        positioned_src = self.src_pos(self.src_embed(src))
        return self.encoder(positioned_src, src_mask)

    def decode(self, encoder_output, src_mask, tgt, tgt_mask):
        """Embed ``tgt``, apply the target positional encoding, run the decoder."""
        positioned_tgt = self.tgt_pos(self.tgt_embed(tgt))
        return self.decoder(positioned_tgt, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        """Apply the projection layer to ``x``."""
        projected = self.projection_layer(x)
        return projected
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int = 512, N: int = 6, h: int = 8, dropout: float = 0.1, d_ff: int = 2048) -> "Transformer":
    """Assemble a ``Transformer`` and Xavier-initialise its weight matrices.

    Args:
        src_vocab_size: vocabulary size given to the source ``InputEmbeddings``.
        tgt_vocab_size: vocabulary size given to the target ``InputEmbeddings``
            and to the ``ProjectionLayer``.
        src_seq_len: sequence length given to the source ``PositionalEncoding``.
        tgt_seq_len: sequence length given to the target ``PositionalEncoding``.
        d_model: model width shared by every component.
        N: number of encoder blocks and, separately, of decoder blocks.
        h: second argument of every ``MultiHeadAttention``
            (presumably the number of heads; confirm against that class).
        dropout: dropout value forwarded to the positional encodings, the
            attention blocks, the feed-forward blocks and the encoder/decoder
            blocks. It is a float; the previous ``int`` annotation was wrong.
        d_ff: inner width given to every ``FeedForwardBlock``.

    Returns:
        The constructed ``Transformer``, with every parameter of more than one
        dimension re-initialised via ``nn.init.xavier_uniform_``.
    """
    # Create the embedding layers
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

    # Create the positional encoding layers
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    # Create the encoder blocks. Arguments are evaluated left to right, so the
    # sub-modules are built in the same order as the former explicit loop.
    encoder_blocks = [
        EncoderBlock(
            MultiHeadAttention(d_model, h, dropout),
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]

    # Create the decoder blocks: self-attention, cross-attention, feed-forward.
    decoder_blocks = [
        DecoderBlock(
            MultiHeadAttention(d_model, h, dropout),
            MultiHeadAttention(d_model, h, dropout),
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]

    # Create the encoder and decoder
    encoder = Encoder(nn.ModuleList(encoder_blocks))
    decoder = Decoder(nn.ModuleList(decoder_blocks))

    # Create the projection layer
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

    # Build the transformer
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

    # Initialize the parameters. Only tensors with more than one dimension
    # are touched; 1-D parameters keep the values set by their constructors.
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer
|
| 281 |
+
|