Kush26 committed on
Commit d5dde55 · verified · 1 Parent(s): b49941f

Update app/model_def.py

Files changed (1)
  1. app/model_def.py +21 -56
app/model_def.py CHANGED
@@ -18,23 +18,12 @@ class PositionalEncoding(nn.Module):
         self.d_model = d_model
         self.seq_len = seq_len
         self.dropout = nn.Dropout(dropout)
-
-        # Create a matrix of shape (seq_len, d_model)
         pe = torch.zeros(seq_len, d_model)
-
-        # Create a vector of shape (seq_len, 1)
         position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
-
-        # Apply sin to even indices
         pe[:, 0::2] = torch.sin(position * div_term)
-        # Apply cos to odd indices
         pe[:, 1::2] = torch.cos(position * div_term)
-
-        # Add a batch dimension
-        pe = pe.unsqueeze(0) # (1, seq_len, d_model)
-
-        # Register 'pe' as a buffer, so it's not a model parameter
+        pe = pe.unsqueeze(0)
         self.register_buffer('pe', pe)
 
     def forward(self, x):
@@ -45,8 +34,8 @@ class LayerNorm(nn.Module):
     def __init__(self, d_model: int, epsilon: float = 1e-6):
         super().__init__()
         self.epsilon = epsilon
-        self.gamma = nn.Parameter(torch.ones(d_model)) # Multiplicative
-        self.beta = nn.Parameter(torch.zeros(d_model)) # Additive
+        self.gamma = nn.Parameter(torch.ones(d_model))
+        self.beta = nn.Parameter(torch.zeros(d_model))
 
     def forward(self, x):
         mean = x.mean(dim=-1, keepdim=True)
@@ -61,7 +50,6 @@ class FeedForward(nn.Module):
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, x):
-        # (Batch, Seq_Len, d_model) -> (Batch, Seq_Len, d_ff) -> (Batch, Seq_Len, d_model)
         return self.layer2(self.dropout(torch.relu(self.layer1(x))))
 
 class MHA(nn.Module):
@@ -70,7 +58,6 @@ class MHA(nn.Module):
         self.d_model = d_model
         self.h = h
         assert d_model % h == 0, "d_model must be divisible by h"
-
         self.d_k = d_model // h
         self.w_q = nn.Linear(d_model, d_model)
         self.w_k = nn.Linear(d_model, d_model)
@@ -84,28 +71,20 @@ class MHA(nn.Module):
         attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
         if mask is not None:
             attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
-
         attention_scores = attention_scores.softmax(dim=-1)
         if dropout is not None:
             attention_scores = dropout(attention_scores)
-
         return (attention_scores @ value), attention_scores
 
     def forward(self, q, k, v, mask):
-        query = self.w_q(q) # (Batch, Seq_Len, d_model)
-        key = self.w_k(k) # (Batch, Seq_Len, d_model)
-        value = self.w_v(v) # (Batch, Seq_Len, d_model)
-
-        # (Batch, Seq_Len, d_model) -> (Batch, Seq_Len, h, d_k) -> (Batch, h, Seq_Len, d_k)
+        query = self.w_q(q)
+        key = self.w_k(k)
+        value = self.w_v(v)
         query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
         key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
         value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)
-
         x, self.attention_scores = MHA.attention(query, key, value, mask, self.dropout)
-
-        # (Batch, h, Seq_Len, d_k) -> (Batch, Seq_Len, h, d_k) -> (Batch, Seq_Len, d_model)
         x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
-
         return self.w_o(x)
 
 class SkipConnection(nn.Module):
@@ -115,19 +94,20 @@ class SkipConnection(nn.Module):
         self.norm = LayerNorm(d_model)
 
     def forward(self, x, sublayer):
-        # Pre-Norm architecture
         return x + self.dropout(sublayer(self.norm(x)))
 
 class EncoderBlock(nn.Module):
     def __init__(self, self_attention: MHA, ffn: FeedForward, d_model: int, dropout: float):
         super().__init__()
-        self.attention = self_attention
+        # Name required by the saved model file
+        self.attention = self_attention
         self.ffn = ffn
-        self.residual = nn.ModuleList([SkipConnection(d_model, dropout) for _ in range(2)])
-
+        # Name required by the saved model file
+        self.residual = nn.ModuleList([SkipConnection(d_model, dropout) for _ in range(2)])
+
     def forward(self, x, src_mask):
-        x = self.residual[0](x, lambda x: self.attention(x, x, x, src_mask))
-        x = self.residual[1](x, self.ffn)
+        x = self.residual[0](x, lambda x: self.attention(x, x, x, src_mask))
+        x = self.residual[1](x, self.ffn)
         return x
 
 class Encoder(nn.Module):
@@ -144,15 +124,17 @@ class Encoder(nn.Module):
 class DecoderBlock(nn.Module):
     def __init__(self, self_attention: MHA, cross_attention: MHA, ffn: FeedForward, d_model: int, dropout: float):
         super().__init__()
-        self.attention = self_attention
+        # Name required by the saved model file
+        self.self_attention = self_attention
         self.cross_attention = cross_attention
         self.ffn = ffn
-        self.residual = nn.ModuleList([SkipConnection(d_model, dropout) for _ in range(3)])
-
+        # Name required by the saved model file
+        self.residual = nn.ModuleList([SkipConnection(d_model, dropout) for _ in range(3)])
+
     def forward(self, x, encoder_output, src_mask, trg_mask):
-        x = self.residual[0](x, lambda x: self.attention(x, x, x, trg_mask))
-        x = self.residual[1](x, lambda x: self.cross_attention(x, encoder_output, encoder_output, src_mask))
-        x = self.residual[2](x, self.ffn)
+        x = self.residual[0](x, lambda x: self.self_attention(x, x, x, trg_mask))
+        x = self.residual[1](x, lambda x: self.cross_attention(x, encoder_output, encoder_output, src_mask))
+        x = self.residual[2](x, self.ffn)
         return x
 
 class Decoder(nn.Module):
@@ -172,7 +154,6 @@ class Output(nn.Module):
         self.proj = nn.Linear(d_model, vocab_size)
 
     def forward(self, x):
-        # (Batch, Seq_Len, d_model) -> (Batch, Seq_Len, vocab_size)
         return self.proj(x)
 
 class Transformer(nn.Module):
@@ -200,23 +181,16 @@ class Transformer(nn.Module):
         return self.output_layer(x)
 
 def BuildTransformer(src_vocab_size: int, trg_vocab_size: int, src_seq_len: int, trg_seq_len: int, d_model: int = 512, N: int = 6, h: int = 8, dropout: float = 0.1, d_ff: int = 2048) -> Transformer:
-    # Create the embedding layers
     src_embed = InputEmbedding(d_model, src_vocab_size)
     trg_embed = InputEmbedding(d_model, trg_vocab_size)
-
-    # Create the positional encoding layers
     src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
     trg_pos = PositionalEncoding(d_model, trg_seq_len, dropout)
-
-    # Create the encoder blocks
     encoder_blocks = []
     for _ in range(N):
         encoder_self_attention = MHA(d_model, h, dropout)
         ffn = FeedForward(d_model, d_ff, dropout)
         encoder_block = EncoderBlock(encoder_self_attention, ffn, d_model, dropout)
         encoder_blocks.append(encoder_block)
-
-    # Create the decoder blocks
     decoder_blocks = []
     for _ in range(N):
         decoder_self_attention = MHA(d_model, h, dropout)
@@ -224,20 +198,11 @@ def BuildTransformer(src_vocab_size: int, trg_vocab_size: int, src_seq_len: int,
         ffn = FeedForward(d_model, d_ff, dropout)
         decoder_block = DecoderBlock(decoder_self_attention, cross_attention, ffn, d_model, dropout)
         decoder_blocks.append(decoder_block)
-
-    # Create the encoder and decoder
     encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
     decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))
-
-    # Create the projection layer
     projection = Output(d_model, trg_vocab_size)
-
-    # Create the transformer
    transformer = Transformer(encoder, decoder, src_embed, trg_embed, src_pos, trg_pos, projection)
-
-    # Initialize parameters
     for p in transformer.parameters():
         if p.dim() > 1:
             nn.init.xavier_uniform_(p)
-
     return transformer
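
A minimal usage sketch (not part of the commit) of why the "Name required by the saved model file" attribute names matter: a PyTorch module's state_dict keys are built from attribute names, so they must match the names used when the checkpoint was written. The vocabulary/sequence sizes and the checkpoint filename below are hypothetical placeholders.

    import torch

    from app.model_def import BuildTransformer

    # Hypothetical sizes; the real values must match the trained checkpoint.
    model = BuildTransformer(src_vocab_size=32000, trg_vocab_size=32000,
                             src_seq_len=350, trg_seq_len=350)

    # State-dict keys are derived from attribute names, e.g. an EncoderBlock's
    # "attention" and "residual" attributes show up inside keys such as
    # "...attention.w_q.weight" and "...residual.0.norm.gamma".
    print(list(model.state_dict())[:5])

    # Loading a previously saved checkpoint (hypothetical path, assumed to hold
    # the state dict directly) only succeeds when those names line up;
    # otherwise load_state_dict() reports missing/unexpected keys.
    state = torch.load("model_weights.pt", map_location="cpu")
    model.load_state_dict(state)
    model.eval()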