Kush26 committed
Commit b49941f · verified · 1 Parent(s): cf12f14

Update app/model_def.py

Files changed (1)
  1. app/model_def.py +242 -242
app/model_def.py CHANGED
@@ -1,243 +1,243 @@
- import torch
- import torch.nn as nn
- import math
-
- class InputEmbedding(nn.Module):
-     def __init__(self, d_model: int, vocab_size: int):
-         super().__init__()
-         self.d_model = d_model
-         self.vocab_size = vocab_size
-         self.embed = nn.Embedding(vocab_size, d_model)
-
-     def forward(self, x):
-         return self.embed(x) * math.sqrt(self.d_model)
-
- class PositionalEncoding(nn.Module):
-     def __init__(self, d_model: int, seq_len: int, dropout: float):
-         super().__init__()
-         self.d_model = d_model
-         self.seq_len = seq_len
-         self.dropout = nn.Dropout(dropout)
-
-         # Create a matrix of shape (seq_len, d_model)
-         pe = torch.zeros(seq_len, d_model)
-
-         # Create a vector of shape (seq_len, 1)
-         position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
-         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
-
-         # Apply sin to even indices
-         pe[:, 0::2] = torch.sin(position * div_term)
-         # Apply cos to odd indices
-         pe[:, 1::2] = torch.cos(position * div_term)
-
-         # Add a batch dimension
-         pe = pe.unsqueeze(0) # (1, seq_len, d_model)
-
-         # Register 'pe' as a buffer, so it's not a model parameter
-         self.register_buffer('pe', pe)
-
-     def forward(self, x):
-         x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
-         return self.dropout(x)
-
- class LayerNorm(nn.Module):
-     def __init__(self, d_model: int, epsilon: float = 1e-6):
-         super().__init__()
-         self.epsilon = epsilon
-         self.gamma = nn.Parameter(torch.ones(d_model)) # Multiplicative
-         self.beta = nn.Parameter(torch.zeros(d_model)) # Additive
-
-     def forward(self, x):
-         mean = x.mean(dim=-1, keepdim=True)
-         std = x.std(dim=-1, keepdim=True)
-         return self.gamma * (x - mean) / (std + self.epsilon) + self.beta
-
- class FeedForward(nn.Module):
-     def __init__(self, d_model: int, d_ff: int, dropout: float):
-         super().__init__()
-         self.layer1 = nn.Linear(d_model, d_ff)
-         self.layer2 = nn.Linear(d_ff, d_model)
-         self.dropout = nn.Dropout(dropout)
-
-     def forward(self, x):
-         # (Batch, Seq_Len, d_model) -> (Batch, Seq_Len, d_ff) -> (Batch, Seq_Len, d_model)
-         return self.layer2(self.dropout(torch.relu(self.layer1(x))))
-
- class MHA(nn.Module):
-     def __init__(self, d_model: int, h: int, dropout: float):
-         super().__init__()
-         self.d_model = d_model
-         self.h = h
-         assert d_model % h == 0, "d_model must be divisible by h"
-
-         self.d_k = d_model // h
-         self.w_q = nn.Linear(d_model, d_model)
-         self.w_k = nn.Linear(d_model, d_model)
-         self.w_v = nn.Linear(d_model, d_model)
-         self.w_o = nn.Linear(d_model, d_model)
-         self.dropout = nn.Dropout(dropout)
-
-     @staticmethod
-     def attention(query, key, value, mask, dropout: nn.Dropout):
-         d_k = query.shape[-1]
-         attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
-         if mask is not None:
-             attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
-
-         attention_scores = attention_scores.softmax(dim=-1)
-         if dropout is not None:
-             attention_scores = dropout(attention_scores)
-
-         return (attention_scores @ value), attention_scores
-
-     def forward(self, q, k, v, mask):
-         query = self.w_q(q) # (Batch, Seq_Len, d_model)
-         key = self.w_k(k) # (Batch, Seq_Len, d_model)
-         value = self.w_v(v) # (Batch, Seq_Len, d_model)
-
-         # (Batch, Seq_Len, d_model) -> (Batch, Seq_Len, h, d_k) -> (Batch, h, Seq_Len, d_k)
-         query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
-         key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
-         value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)
-
-         x, self.attention_scores = MHA.attention(query, key, value, mask, self.dropout)
-
-         # (Batch, h, Seq_Len, d_k) -> (Batch, Seq_Len, h, d_k) -> (Batch, Seq_Len, d_model)
-         x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
-
-         return self.w_o(x)
-
- class SkipConnection(nn.Module):
-     def __init__(self, d_model: int, dropout: float):
-         super().__init__()
-         self.dropout = nn.Dropout(dropout)
-         self.norm = LayerNorm(d_model)
-
-     def forward(self, x, sublayer):
-         # Pre-Norm architecture
-         return x + self.dropout(sublayer(self.norm(x)))
-
- class EncoderBlock(nn.Module):
-     def __init__(self, self_attention: MHA, ffn: FeedForward, d_model: int, dropout: float):
-         super().__init__()
-         self.self_attention = self_attention
-         self.ffn = ffn
-         self.skip_connections = nn.ModuleList([SkipConnection(d_model, dropout) for _ in range(2)])
-
-     def forward(self, x, src_mask):
-         x = self.skip_connections[0](x, lambda x: self.self_attention(x, x, x, src_mask))
-         x = self.skip_connections[1](x, self.ffn)
-         return x
-
- class Encoder(nn.Module):
-     def __init__(self, d_model: int, layers: nn.ModuleList):
-         super().__init__()
-         self.layers = layers
-         self.norm = LayerNorm(d_model)
-
-     def forward(self, x, mask):
-         for layer in self.layers:
-             x = layer(x, mask)
-         return self.norm(x)
-
- class DecoderBlock(nn.Module):
-     def __init__(self, self_attention: MHA, cross_attention: MHA, ffn: FeedForward, d_model: int, dropout: float):
-         super().__init__()
-         self.self_attention = self_attention
-         self.cross_attention = cross_attention
-         self.ffn = ffn
-         self.skip_connections = nn.ModuleList([SkipConnection(d_model, dropout) for _ in range(3)])
-
-     def forward(self, x, encoder_output, src_mask, trg_mask):
-         x = self.skip_connections[0](x, lambda x: self.self_attention(x, x, x, trg_mask))
-         x = self.skip_connections[1](x, lambda x: self.cross_attention(x, encoder_output, encoder_output, src_mask))
-         x = self.skip_connections[2](x, self.ffn)
-         return x
-
- class Decoder(nn.Module):
-     def __init__(self, d_model: int, layers: nn.ModuleList):
-         super().__init__()
-         self.layers = layers
-         self.norm = LayerNorm(d_model)
-
-     def forward(self, x, encoder_output, src_mask, trg_mask):
-         for layer in self.layers:
-             x = layer(x, encoder_output, src_mask, trg_mask)
-         return self.norm(x)
-
- class Output(nn.Module):
-     def __init__(self, d_model: int, vocab_size: int):
-         super().__init__()
-         self.proj = nn.Linear(d_model, vocab_size)
-
-     def forward(self, x):
-         # (Batch, Seq_Len, d_model) -> (Batch, Seq_Len, vocab_size)
-         return self.proj(x)
-
- class Transformer(nn.Module):
-     def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbedding, trg_embed: InputEmbedding, src_pos: PositionalEncoding, trg_pos: PositionalEncoding, output: Output):
-         super().__init__()
-         self.encoder = encoder
-         self.decoder = decoder
-         self.src_embed = src_embed
-         self.trg_embed = trg_embed
-         self.src_pos = src_pos
-         self.trg_pos = trg_pos
-         self.output_layer = output
-
-     def encode(self, src, src_mask):
-         src = self.src_embed(src)
-         src = self.src_pos(src)
-         return self.encoder(src, src_mask)
-
-     def decode(self, encoder_output, src_mask, trg, trg_mask):
-         trg = self.trg_embed(trg)
-         trg = self.trg_pos(trg)
-         return self.decoder(trg, encoder_output, src_mask, trg_mask)
-
-     def project(self, x):
-         return self.output_layer(x)
-
- def BuildTransformer(src_vocab_size: int, trg_vocab_size: int, src_seq_len: int, trg_seq_len: int, d_model: int = 512, N: int = 6, h: int = 8, dropout: float = 0.1, d_ff: int = 2048) -> Transformer:
-     # Create the embedding layers
-     src_embed = InputEmbedding(d_model, src_vocab_size)
-     trg_embed = InputEmbedding(d_model, trg_vocab_size)
-
-     # Create the positional encoding layers
-     src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
-     trg_pos = PositionalEncoding(d_model, trg_seq_len, dropout)
-
-     # Create the encoder blocks
-     encoder_blocks = []
-     for _ in range(N):
-         encoder_self_attention = MHA(d_model, h, dropout)
-         ffn = FeedForward(d_model, d_ff, dropout)
-         encoder_block = EncoderBlock(encoder_self_attention, ffn, d_model, dropout)
-         encoder_blocks.append(encoder_block)
-
-     # Create the decoder blocks
-     decoder_blocks = []
-     for _ in range(N):
-         decoder_self_attention = MHA(d_model, h, dropout)
-         cross_attention = MHA(d_model, h, dropout)
-         ffn = FeedForward(d_model, d_ff, dropout)
-         decoder_block = DecoderBlock(decoder_self_attention, cross_attention, ffn, d_model, dropout)
-         decoder_blocks.append(decoder_block)
-
-     # Create the encoder and decoder
-     encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
-     decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))
-
-     # Create the projection layer
-     projection = Output(d_model, trg_vocab_size)
-
-     # Create the transformer
-     transformer = Transformer(encoder, decoder, src_embed, trg_embed, src_pos, trg_pos, projection)
-
-     # Initialize parameters
-     for p in transformer.parameters():
-         if p.dim() > 1:
-             nn.init.xavier_uniform_(p)
-
+ import torch
+ import torch.nn as nn
+ import math
+
+ class InputEmbedding(nn.Module):
+     def __init__(self, d_model: int, vocab_size: int):
+         super().__init__()
+         self.d_model = d_model
+         self.vocab_size = vocab_size
+         self.embed = nn.Embedding(vocab_size, d_model)
+
+     def forward(self, x):
+         return self.embed(x) * math.sqrt(self.d_model)
+
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model: int, seq_len: int, dropout: float):
+         super().__init__()
+         self.d_model = d_model
+         self.seq_len = seq_len
+         self.dropout = nn.Dropout(dropout)
+
+         # Create a matrix of shape (seq_len, d_model)
+         pe = torch.zeros(seq_len, d_model)
+
+         # Create a vector of shape (seq_len, 1)
+         position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+
+         # Apply sin to even indices
+         pe[:, 0::2] = torch.sin(position * div_term)
+         # Apply cos to odd indices
+         pe[:, 1::2] = torch.cos(position * div_term)
+
+         # Add a batch dimension
+         pe = pe.unsqueeze(0) # (1, seq_len, d_model)
+
+         # Register 'pe' as a buffer, so it's not a model parameter
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
+         return self.dropout(x)
+
+ class LayerNorm(nn.Module):
+     def __init__(self, d_model: int, epsilon: float = 1e-6):
+         super().__init__()
+         self.epsilon = epsilon
+         self.gamma = nn.Parameter(torch.ones(d_model)) # Multiplicative
+         self.beta = nn.Parameter(torch.zeros(d_model)) # Additive
+
+     def forward(self, x):
+         mean = x.mean(dim=-1, keepdim=True)
+         std = x.std(dim=-1, keepdim=True)
+         return self.gamma * (x - mean) / (std + self.epsilon) + self.beta
+
+ class FeedForward(nn.Module):
+     def __init__(self, d_model: int, d_ff: int, dropout: float):
+         super().__init__()
+         self.layer1 = nn.Linear(d_model, d_ff)
+         self.layer2 = nn.Linear(d_ff, d_model)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         # (Batch, Seq_Len, d_model) -> (Batch, Seq_Len, d_ff) -> (Batch, Seq_Len, d_model)
+         return self.layer2(self.dropout(torch.relu(self.layer1(x))))
+
+ class MHA(nn.Module):
+     def __init__(self, d_model: int, h: int, dropout: float):
+         super().__init__()
+         self.d_model = d_model
+         self.h = h
+         assert d_model % h == 0, "d_model must be divisible by h"
+
+         self.d_k = d_model // h
+         self.w_q = nn.Linear(d_model, d_model)
+         self.w_k = nn.Linear(d_model, d_model)
+         self.w_v = nn.Linear(d_model, d_model)
+         self.w_o = nn.Linear(d_model, d_model)
+         self.dropout = nn.Dropout(dropout)
+
+     @staticmethod
+     def attention(query, key, value, mask, dropout: nn.Dropout):
+         d_k = query.shape[-1]
+         attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
+         if mask is not None:
+             attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
+
+         attention_scores = attention_scores.softmax(dim=-1)
+         if dropout is not None:
+             attention_scores = dropout(attention_scores)
+
+         return (attention_scores @ value), attention_scores
+
+     def forward(self, q, k, v, mask):
+         query = self.w_q(q) # (Batch, Seq_Len, d_model)
+         key = self.w_k(k) # (Batch, Seq_Len, d_model)
+         value = self.w_v(v) # (Batch, Seq_Len, d_model)
+
+         # (Batch, Seq_Len, d_model) -> (Batch, Seq_Len, h, d_k) -> (Batch, h, Seq_Len, d_k)
+         query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
+         key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
+         value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)
+
+         x, self.attention_scores = MHA.attention(query, key, value, mask, self.dropout)
+
+         # (Batch, h, Seq_Len, d_k) -> (Batch, Seq_Len, h, d_k) -> (Batch, Seq_Len, d_model)
+         x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
+
+         return self.w_o(x)
+
+ class SkipConnection(nn.Module):
+     def __init__(self, d_model: int, dropout: float):
+         super().__init__()
+         self.dropout = nn.Dropout(dropout)
+         self.norm = LayerNorm(d_model)
+
+     def forward(self, x, sublayer):
+         # Pre-Norm architecture
+         return x + self.dropout(sublayer(self.norm(x)))
+
+ class EncoderBlock(nn.Module):
+     def __init__(self, self_attention: MHA, ffn: FeedForward, d_model: int, dropout: float):
+         super().__init__()
+         self.attention = self_attention
+         self.ffn = ffn
+         self.residual = nn.ModuleList([SkipConnection(d_model, dropout) for _ in range(2)])
+
+     def forward(self, x, src_mask):
+         x = self.residual[0](x, lambda x: self.attention(x, x, x, src_mask))
+         x = self.residual[1](x, self.ffn)
+         return x
+
+ class Encoder(nn.Module):
+     def __init__(self, d_model: int, layers: nn.ModuleList):
+         super().__init__()
+         self.layers = layers
+         self.norm = LayerNorm(d_model)
+
+     def forward(self, x, mask):
+         for layer in self.layers:
+             x = layer(x, mask)
+         return self.norm(x)
+
+ class DecoderBlock(nn.Module):
+     def __init__(self, self_attention: MHA, cross_attention: MHA, ffn: FeedForward, d_model: int, dropout: float):
+         super().__init__()
+         self.attention = self_attention
+         self.cross_attention = cross_attention
+         self.ffn = ffn
+         self.residual = nn.ModuleList([SkipConnection(d_model, dropout) for _ in range(3)])
+
+     def forward(self, x, encoder_output, src_mask, trg_mask):
+         x = self.residual[0](x, lambda x: self.attention(x, x, x, trg_mask))
+         x = self.residual[1](x, lambda x: self.cross_attention(x, encoder_output, encoder_output, src_mask))
+         x = self.residual[2](x, self.ffn)
+         return x
+
+ class Decoder(nn.Module):
+     def __init__(self, d_model: int, layers: nn.ModuleList):
+         super().__init__()
+         self.layers = layers
+         self.norm = LayerNorm(d_model)
+
+     def forward(self, x, encoder_output, src_mask, trg_mask):
+         for layer in self.layers:
+             x = layer(x, encoder_output, src_mask, trg_mask)
+         return self.norm(x)
+
+ class Output(nn.Module):
+     def __init__(self, d_model: int, vocab_size: int):
+         super().__init__()
+         self.proj = nn.Linear(d_model, vocab_size)
+
+     def forward(self, x):
+         # (Batch, Seq_Len, d_model) -> (Batch, Seq_Len, vocab_size)
+         return self.proj(x)
+
+ class Transformer(nn.Module):
+     def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbedding, trg_embed: InputEmbedding, src_pos: PositionalEncoding, trg_pos: PositionalEncoding, output: Output):
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self.src_embed = src_embed
+         self.trg_embed = trg_embed
+         self.src_pos = src_pos
+         self.trg_pos = trg_pos
+         self.output_layer = output
+
+     def encode(self, src, src_mask):
+         src = self.src_embed(src)
+         src = self.src_pos(src)
+         return self.encoder(src, src_mask)
+
+     def decode(self, encoder_output, src_mask, trg, trg_mask):
+         trg = self.trg_embed(trg)
+         trg = self.trg_pos(trg)
+         return self.decoder(trg, encoder_output, src_mask, trg_mask)
+
+     def project(self, x):
+         return self.output_layer(x)
+
+ def BuildTransformer(src_vocab_size: int, trg_vocab_size: int, src_seq_len: int, trg_seq_len: int, d_model: int = 512, N: int = 6, h: int = 8, dropout: float = 0.1, d_ff: int = 2048) -> Transformer:
+     # Create the embedding layers
+     src_embed = InputEmbedding(d_model, src_vocab_size)
+     trg_embed = InputEmbedding(d_model, trg_vocab_size)
+
+     # Create the positional encoding layers
+     src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
+     trg_pos = PositionalEncoding(d_model, trg_seq_len, dropout)
+
+     # Create the encoder blocks
+     encoder_blocks = []
+     for _ in range(N):
+         encoder_self_attention = MHA(d_model, h, dropout)
+         ffn = FeedForward(d_model, d_ff, dropout)
+         encoder_block = EncoderBlock(encoder_self_attention, ffn, d_model, dropout)
+         encoder_blocks.append(encoder_block)
+
+     # Create the decoder blocks
+     decoder_blocks = []
+     for _ in range(N):
+         decoder_self_attention = MHA(d_model, h, dropout)
+         cross_attention = MHA(d_model, h, dropout)
+         ffn = FeedForward(d_model, d_ff, dropout)
+         decoder_block = DecoderBlock(decoder_self_attention, cross_attention, ffn, d_model, dropout)
+         decoder_blocks.append(decoder_block)
+
+     # Create the encoder and decoder
+     encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
+     decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))
+
+     # Create the projection layer
+     projection = Output(d_model, trg_vocab_size)
+
+     # Create the transformer
+     transformer = Transformer(encoder, decoder, src_embed, trg_embed, src_pos, trg_pos, projection)
+
+     # Initialize parameters
+     for p in transformer.parameters():
+         if p.dim() > 1:
+             nn.init.xavier_uniform_(p)
+
+
      return transformer
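
For orientation, here is a minimal, untested sketch of how BuildTransformer and the encode/decode/project methods defined above might be exercised. The import path, vocabulary sizes, sequence lengths, and mask shapes are illustrative assumptions, not values taken from this repository.

import torch
from app.model_def import BuildTransformer  # assumed import path, based on app/model_def.py

# Hypothetical sizes; real values would come from the tokenizers and dataset.
model = BuildTransformer(src_vocab_size=32000, trg_vocab_size=32000,
                         src_seq_len=350, trg_seq_len=350)
model.eval()  # disable dropout for this forward-pass sketch

src = torch.randint(0, 32000, (1, 350))            # (Batch, Seq_Len) source token ids
src_mask = torch.ones(1, 1, 1, 350)                # 1 = attend, 0 = masked out
encoder_output = model.encode(src, src_mask)       # (1, 350, 512) with the default d_model

trg = torch.randint(0, 32000, (1, 350))            # (Batch, Seq_Len) target token ids
trg_mask = torch.tril(torch.ones(1, 1, 350, 350))  # causal mask for decoder self-attention
decoder_output = model.decode(encoder_output, src_mask, trg, trg_mask)
logits = model.project(decoder_output)             # (1, 350, 32000) per-token vocabulary scores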