ewdlop committed on
Commit
fd5a74f
·
verified ·
1 Parent(s): 3866726

Add Shakespeare Transformer model - model definition

Browse files
Files changed (1) hide show
  1. encoder_decoder_transformer.py +283 -283
encoder_decoder_transformer.py CHANGED
@@ -1,284 +1,284 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import math
5
-
6
- class PositionalEncoding(nn.Module):
7
- def __init__(self, d_model, max_len=5000):
8
- super().__init__()
9
-
10
- pe = torch.zeros(max_len, d_model)
11
- position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
12
- div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
13
-
14
- pe[:, 0::2] = torch.sin(position * div_term)
15
- pe[:, 1::2] = torch.cos(position * div_term)
16
- pe = pe.unsqueeze(0).transpose(0, 1)
17
-
18
- self.register_buffer('pe', pe)
19
-
20
- def forward(self, x):
21
- return x + self.pe[:x.size(0), :]
22
-
23
- class FeedForward(nn.Module):
24
- def __init__(self, d_model, d_ff):
25
- super().__init__()
26
- self.linear1 = nn.Linear(d_model, d_ff)
27
- self.linear2 = nn.Linear(d_ff, d_model)
28
-
29
- def forward(self, x):
30
- return self.linear2(F.relu(self.linear1(x)))
31
-
32
- class EncoderLayer(nn.Module):
33
- def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
34
- super().__init__()
35
- # Using PyTorch's built-in MultiheadAttention
36
- self.self_attention = nn.MultiheadAttention(
37
- embed_dim=d_model,
38
- num_heads=n_heads,
39
- dropout=dropout,
40
- batch_first=False # PyTorch default: (seq_len, batch, embed_dim)
41
- )
42
- self.feed_forward = FeedForward(d_model, d_ff)
43
- self.norm1 = nn.LayerNorm(d_model)
44
- self.norm2 = nn.LayerNorm(d_model)
45
- self.dropout = nn.Dropout(dropout)
46
-
47
- def forward(self, x, key_padding_mask=None):
48
- # x shape: (seq_len, batch_size, d_model)
49
-
50
- # Multi-head self-attention with residual connection and layer norm
51
- attn_output, _ = self.self_attention(
52
- query=x,
53
- key=x,
54
- value=x,
55
- key_padding_mask=key_padding_mask,
56
- need_weights=False
57
- )
58
- x = self.norm1(x + self.dropout(attn_output))
59
-
60
- # Feed forward with residual connection and layer norm
61
- ff_output = self.feed_forward(x)
62
- x = self.norm2(x + self.dropout(ff_output))
63
-
64
- return x
65
-
66
- class DecoderLayer(nn.Module):
67
- def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
68
- super().__init__()
69
- # Masked self-attention
70
- self.masked_self_attention = nn.MultiheadAttention(
71
- embed_dim=d_model,
72
- num_heads=n_heads,
73
- dropout=dropout,
74
- batch_first=False
75
- )
76
- # Cross-attention (decoder attends to encoder)
77
- self.cross_attention = nn.MultiheadAttention(
78
- embed_dim=d_model,
79
- num_heads=n_heads,
80
- dropout=dropout,
81
- batch_first=False
82
- )
83
- self.feed_forward = FeedForward(d_model, d_ff)
84
- self.norm1 = nn.LayerNorm(d_model)
85
- self.norm2 = nn.LayerNorm(d_model)
86
- self.norm3 = nn.LayerNorm(d_model)
87
- self.dropout = nn.Dropout(dropout)
88
-
89
- def forward(self, x, enc_output, tgt_mask=None, memory_key_padding_mask=None, tgt_key_padding_mask=None):
90
- # x shape: (tgt_seq_len, batch_size, d_model)
91
- # enc_output shape: (src_seq_len, batch_size, d_model)
92
-
93
- # Masked multi-head self-attention
94
- attn_output, _ = self.masked_self_attention(
95
- query=x,
96
- key=x,
97
- value=x,
98
- attn_mask=tgt_mask,
99
- key_padding_mask=tgt_key_padding_mask,
100
- need_weights=False
101
- )
102
- x = self.norm1(x + self.dropout(attn_output))
103
-
104
- # Multi-head cross-attention (decoder attends to encoder)
105
- attn_output, _ = self.cross_attention(
106
- query=x,
107
- key=enc_output,
108
- value=enc_output,
109
- key_padding_mask=memory_key_padding_mask,
110
- need_weights=False
111
- )
112
- x = self.norm2(x + self.dropout(attn_output))
113
-
114
- # Feed forward
115
- ff_output = self.feed_forward(x)
116
- x = self.norm3(x + self.dropout(ff_output))
117
-
118
- return x
119
-
120
- class Transformer(nn.Module):
121
- def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, n_heads=8,
122
- n_encoder_layers=6, n_decoder_layers=6, d_ff=2048, dropout=0.1, pad_idx=0):
123
- super().__init__()
124
-
125
- self.d_model = d_model
126
- self.pad_idx = pad_idx
127
-
128
- # Embeddings
129
- self.src_embedding = nn.Embedding(src_vocab_size, d_model, padding_idx=pad_idx)
130
- self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model, padding_idx=pad_idx)
131
-
132
- # Positional encodings
133
- self.pos_encoding = PositionalEncoding(d_model)
134
-
135
- # Encoder layers
136
- self.encoder_layers = nn.ModuleList([
137
- EncoderLayer(d_model, n_heads, d_ff, dropout)
138
- for _ in range(n_encoder_layers)
139
- ])
140
-
141
- # Decoder layers
142
- self.decoder_layers = nn.ModuleList([
143
- DecoderLayer(d_model, n_heads, d_ff, dropout)
144
- for _ in range(n_decoder_layers)
145
- ])
146
-
147
- # Output projection
148
- self.linear = nn.Linear(d_model, tgt_vocab_size)
149
- self.dropout = nn.Dropout(dropout)
150
-
151
- # Initialize weights
152
- self._init_weights()
153
-
154
- def _init_weights(self):
155
- for p in self.parameters():
156
- if p.dim() > 1:
157
- nn.init.xavier_uniform_(p)
158
-
159
- def create_padding_mask(self, seq):
160
- """Create padding mask for sequences (True for padding tokens)"""
161
- return seq == self.pad_idx
162
-
163
- def create_look_ahead_mask(self, size):
164
- """Create look-ahead mask for decoder (upper triangular matrix)"""
165
- mask = torch.triu(torch.ones(size, size), diagonal=1)
166
- return mask.bool()
167
-
168
- def encode(self, src, src_key_padding_mask=None):
169
- """Encode source sequence"""
170
- # src shape: (batch_size, src_seq_len)
171
- # Convert to (src_seq_len, batch_size, d_model)
172
-
173
- # Source embedding + positional encoding
174
- src_emb = self.src_embedding(src) * math.sqrt(self.d_model) # (batch, seq, d_model)
175
- src_emb = src_emb.transpose(0, 1) # (seq, batch, d_model)
176
- src_emb = self.pos_encoding(src_emb)
177
- src_emb = self.dropout(src_emb)
178
-
179
- # Pass through encoder layers
180
- enc_output = src_emb
181
- for layer in self.encoder_layers:
182
- enc_output = layer(enc_output, key_padding_mask=src_key_padding_mask)
183
-
184
- return enc_output
185
-
186
- def decode(self, tgt, enc_output, tgt_mask=None, memory_key_padding_mask=None, tgt_key_padding_mask=None):
187
- """Decode target sequence"""
188
- # tgt shape: (batch_size, tgt_seq_len)
189
- # Convert to (tgt_seq_len, batch_size, d_model)
190
-
191
- # Target embedding + positional encoding
192
- tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model) # (batch, seq, d_model)
193
- tgt_emb = tgt_emb.transpose(0, 1) # (seq, batch, d_model)
194
- tgt_emb = self.pos_encoding(tgt_emb)
195
- tgt_emb = self.dropout(tgt_emb)
196
-
197
- # Pass through decoder layers
198
- dec_output = tgt_emb
199
- for layer in self.decoder_layers:
200
- dec_output = layer(
201
- dec_output,
202
- enc_output,
203
- tgt_mask=tgt_mask,
204
- memory_key_padding_mask=memory_key_padding_mask,
205
- tgt_key_padding_mask=tgt_key_padding_mask
206
- )
207
-
208
- return dec_output
209
-
210
- def forward(self, src, tgt):
211
- """Forward pass"""
212
- # src shape: (batch_size, src_seq_len)
213
- # tgt shape: (batch_size, tgt_seq_len)
214
-
215
- batch_size, src_seq_len = src.shape
216
- batch_size, tgt_seq_len = tgt.shape
217
-
218
- # Create masks
219
- src_key_padding_mask = self.create_padding_mask(src) # (batch, src_seq)
220
- tgt_key_padding_mask = self.create_padding_mask(tgt) # (batch, tgt_seq)
221
- tgt_mask = self.create_look_ahead_mask(tgt_seq_len).to(tgt.device) # (tgt_seq, tgt_seq)
222
-
223
- # Encode
224
- enc_output = self.encode(src, src_key_padding_mask)
225
-
226
- # Decode
227
- dec_output = self.decode(
228
- tgt,
229
- enc_output,
230
- tgt_mask=tgt_mask,
231
- memory_key_padding_mask=src_key_padding_mask,
232
- tgt_key_padding_mask=tgt_key_padding_mask
233
- )
234
-
235
- # Final linear transformation
236
- # Convert back to (batch, seq, d_model)
237
- dec_output = dec_output.transpose(0, 1)
238
- output = self.linear(dec_output)
239
-
240
- # Apply softmax to get probabilities
241
- output_probs = F.softmax(output, dim=-1)
242
-
243
- return output_probs
244
-
245
- def generate(self, src, max_len=50, start_token=1, end_token=2):
246
- """Generate sequence using greedy decoding"""
247
- self.eval()
248
- device = src.device
249
- batch_size = src.size(0)
250
-
251
- # Encode source
252
- src_key_padding_mask = self.create_padding_mask(src)
253
- enc_output = self.encode(src, src_key_padding_mask)
254
-
255
- # Initialize target with start token
256
- tgt = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
257
-
258
- for i in range(max_len - 1):
259
- # Create masks for current target
260
- tgt_key_padding_mask = self.create_padding_mask(tgt)
261
- tgt_mask = self.create_look_ahead_mask(tgt.size(1)).to(device)
262
-
263
- # Decode
264
- dec_output = self.decode(
265
- tgt,
266
- enc_output,
267
- tgt_mask=tgt_mask,
268
- memory_key_padding_mask=src_key_padding_mask,
269
- tgt_key_padding_mask=tgt_key_padding_mask
270
- )
271
-
272
- # Get next token probabilities
273
- dec_output = dec_output.transpose(0, 1) # (batch, seq, d_model)
274
- next_token_logits = self.linear(dec_output[:, -1, :]) # (batch, vocab_size)
275
- next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) # (batch, 1)
276
-
277
- # Append to target sequence
278
- tgt = torch.cat([tgt, next_token], dim=1)
279
-
280
- # Check if all sequences have generated end token
281
- if (next_token == end_token).all():
282
- break
283
-
284
  return tgt
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al., 2017).

    Adds fixed position-dependent sine/cosine signals to token embeddings so
    attention can use token order. Expects sequence-first input:
    (seq_len, batch_size, d_model).
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Geometric progression of frequencies, one per even channel.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # Fix: slice div_term so odd d_model no longer raises a size mismatch —
        # there are d_model // 2 cosine channels, one fewer than sine channels
        # when d_model is odd. Even d_model is unchanged (slice is a no-op).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # (max_len, d_model) -> (max_len, 1, d_model): broadcasts over batch.
        pe = pe.unsqueeze(0).transpose(0, 1)

        # Buffer, not Parameter: moves with .to(device) and is saved in
        # state_dict, but is excluded from optimization.
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (seq_len, batch, d_model); add the first seq_len position rows.
        return x + self.pe[:x.size(0), :]
22
+
23
class FeedForward(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Linear.

    Expands d_model to d_ff, applies ReLU, and projects back to d_model.
    The same transform is applied independently at every position.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # Expand, gate with ReLU, then contract back to the model width.
        hidden = F.relu(self.linear1(x))
        return self.linear2(hidden)
31
+
32
class EncoderLayer(nn.Module):
    """Single Transformer encoder layer (post-norm).

    Two sub-layers — multi-head self-attention and a position-wise
    feed-forward network — each wrapped with dropout, a residual
    connection, and LayerNorm. Operates on sequence-first tensors:
    (seq_len, batch_size, d_model).
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        # PyTorch's built-in multi-head attention; batch_first=False keeps
        # the default (seq_len, batch, embed_dim) layout.
        self.self_attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=False,
        )
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        """Run one encoder layer.

        Args:
            x: (seq_len, batch, d_model) input.
            key_padding_mask: optional (batch, seq_len) bool mask, True at
                padding positions to exclude them from attention.

        Returns:
            (seq_len, batch, d_model) tensor.
        """
        # Sub-layer 1: self-attention, then residual + norm.
        attended, _ = self.self_attention(
            query=x,
            key=x,
            value=x,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = self.norm1(x + self.dropout(attended))

        # Sub-layer 2: feed-forward, then residual + norm.
        transformed = self.feed_forward(x)
        return self.norm2(x + self.dropout(transformed))
65
+
66
class DecoderLayer(nn.Module):
    """Single Transformer decoder layer (post-norm).

    Three sub-layers — masked self-attention, cross-attention over the
    encoder output, and a position-wise feed-forward network — each wrapped
    with dropout, a residual connection, and LayerNorm. Sequence-first
    layout throughout: (seq_len, batch, d_model).
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        # Self-attention over the target sequence; the causal mask is
        # supplied per call via tgt_mask.
        self.masked_self_attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=False,
        )
        # Cross-attention: target queries attend to the encoder memory.
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=False,
        )
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, tgt_mask=None, memory_key_padding_mask=None, tgt_key_padding_mask=None):
        """Run one decoder layer.

        Args:
            x: (tgt_seq_len, batch, d_model) decoder input.
            enc_output: (src_seq_len, batch, d_model) encoder memory.
            tgt_mask: optional (tgt_seq_len, tgt_seq_len) causal mask.
            memory_key_padding_mask: optional (batch, src_seq_len) bool mask.
            tgt_key_padding_mask: optional (batch, tgt_seq_len) bool mask.

        Returns:
            (tgt_seq_len, batch, d_model) tensor.
        """
        # Sub-layer 1: masked self-attention, then residual + norm.
        self_attended, _ = self.masked_self_attention(
            query=x,
            key=x,
            value=x,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
            need_weights=False,
        )
        x = self.norm1(x + self.dropout(self_attended))

        # Sub-layer 2: cross-attention into encoder memory, residual + norm.
        cross_attended, _ = self.cross_attention(
            query=x,
            key=enc_output,
            value=enc_output,
            key_padding_mask=memory_key_padding_mask,
            need_weights=False,
        )
        x = self.norm2(x + self.dropout(cross_attended))

        # Sub-layer 3: feed-forward, then residual + norm.
        return self.norm3(x + self.dropout(self.feed_forward(x)))
119
+
120
class Transformer(nn.Module):
    """Encoder-decoder Transformer for sequence-to-sequence modeling.

    Public API (unchanged): `forward(src, tgt)` returns per-token softmax
    probabilities of shape (batch, tgt_seq_len, tgt_vocab_size);
    `generate(src, ...)` greedily decodes token ids.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, n_heads=8,
                 n_encoder_layers=6, n_decoder_layers=6, d_ff=2048, dropout=0.1, pad_idx=0):
        super().__init__()

        self.d_model = d_model
        self.pad_idx = pad_idx

        # Embeddings; padding rows are kept at zero (see _init_weights).
        self.src_embedding = nn.Embedding(src_vocab_size, d_model, padding_idx=pad_idx)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model, padding_idx=pad_idx)

        # Shared sinusoidal positional encoding (non-trainable buffer).
        self.pos_encoding = PositionalEncoding(d_model)

        # Encoder / decoder stacks.
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_encoder_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_decoder_layers)
        ])

        # Projection from d_model to target vocabulary logits.
        self.linear = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialize all matrices, then restore zero padding rows.

        Bug fix: xavier_uniform_ over every >1-dim parameter also overwrote
        the padding rows that nn.Embedding(padding_idx=...) had zeroed, so
        padding tokens contributed non-zero vectors. Re-zero them here.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        with torch.no_grad():
            self.src_embedding.weight[self.pad_idx].fill_(0)
            self.tgt_embedding.weight[self.pad_idx].fill_(0)

    def create_padding_mask(self, seq):
        """Return a (batch, seq_len) bool mask, True at padding tokens."""
        return seq == self.pad_idx

    def create_look_ahead_mask(self, size):
        """Return a (size, size) bool causal mask, True above the diagonal.

        True entries are positions a query may NOT attend to, matching
        nn.MultiheadAttention's convention for boolean attn_mask.
        """
        mask = torch.triu(torch.ones(size, size), diagonal=1)
        return mask.bool()

    def encode(self, src, src_key_padding_mask=None):
        """Encode (batch, src_seq_len) token ids to (src_seq_len, batch, d_model)."""
        # Scale embeddings by sqrt(d_model) as in the original paper, then
        # switch to the sequence-first layout the layers expect.
        src_emb = self.src_embedding(src) * math.sqrt(self.d_model)  # (batch, seq, d_model)
        src_emb = src_emb.transpose(0, 1)  # (seq, batch, d_model)
        src_emb = self.dropout(self.pos_encoding(src_emb))

        enc_output = src_emb
        for layer in self.encoder_layers:
            enc_output = layer(enc_output, key_padding_mask=src_key_padding_mask)
        return enc_output

    def decode(self, tgt, enc_output, tgt_mask=None, memory_key_padding_mask=None, tgt_key_padding_mask=None):
        """Decode (batch, tgt_seq_len) ids against encoder memory.

        Returns (tgt_seq_len, batch, d_model).
        """
        tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)  # (batch, seq, d_model)
        tgt_emb = tgt_emb.transpose(0, 1)  # (seq, batch, d_model)
        tgt_emb = self.dropout(self.pos_encoding(tgt_emb))

        dec_output = tgt_emb
        for layer in self.decoder_layers:
            dec_output = layer(
                dec_output,
                enc_output,
                tgt_mask=tgt_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
            )
        return dec_output

    def forward(self, src, tgt):
        """Full encode-decode pass (teacher forcing).

        Args:
            src: (batch, src_seq_len) source token ids.
            tgt: (batch, tgt_seq_len) target token ids.

        Returns:
            (batch, tgt_seq_len, tgt_vocab_size) softmax probabilities.
        """
        tgt_seq_len = tgt.size(1)

        # Boolean masks: padding per sequence, causal over target positions.
        src_key_padding_mask = self.create_padding_mask(src)  # (batch, src_seq)
        tgt_key_padding_mask = self.create_padding_mask(tgt)  # (batch, tgt_seq)
        tgt_mask = self.create_look_ahead_mask(tgt_seq_len).to(tgt.device)  # (tgt_seq, tgt_seq)

        enc_output = self.encode(src, src_key_padding_mask)
        dec_output = self.decode(
            tgt,
            enc_output,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
        )

        # Back to (batch, seq, d_model), then project to vocabulary size.
        output = self.linear(dec_output.transpose(0, 1))

        # NOTE(review): returning softmax probabilities is kept for backward
        # compatibility, but nn.CrossEntropyLoss expects raw logits —
        # training code should use the pre-softmax `output` (or this model
        # should be changed to return logits).
        return F.softmax(output, dim=-1)

    def generate(self, src, max_len=50, start_token=1, end_token=2):
        """Greedy decoding from (batch, src_seq_len) source ids.

        Stops early once every sequence has emitted `end_token`, otherwise
        after max_len tokens. Returns (batch, <=max_len) generated ids,
        including the leading start token. Puts the module in eval mode as
        a side effect (unchanged from the original).
        """
        self.eval()
        device = src.device
        batch_size = src.size(0)

        # Fix: inference only — disable autograd so the decoding loop does
        # not build an ever-growing computation graph.
        with torch.no_grad():
            # Encode the source once; reuse the memory for every step.
            src_key_padding_mask = self.create_padding_mask(src)
            enc_output = self.encode(src, src_key_padding_mask)

            # Seed every sequence with the start token.
            tgt = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)

            for _ in range(max_len - 1):
                tgt_key_padding_mask = self.create_padding_mask(tgt)
                tgt_mask = self.create_look_ahead_mask(tgt.size(1)).to(device)

                dec_output = self.decode(
                    tgt,
                    enc_output,
                    tgt_mask=tgt_mask,
                    memory_key_padding_mask=src_key_padding_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                )

                # Project only the final position to choose the next token.
                dec_output = dec_output.transpose(0, 1)  # (batch, seq, d_model)
                next_token_logits = self.linear(dec_output[:, -1, :])  # (batch, vocab)
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # (batch, 1)

                tgt = torch.cat([tgt, next_token], dim=1)

                # Stop once every sequence in the batch has finished.
                if (next_token == end_token).all():
                    break

        return tgt