Omnia-cy committed on
Commit
2de6d24
·
verified ·
1 Parent(s): 72759f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -84
app.py CHANGED
@@ -11,9 +11,9 @@ import gradio as gr
11
  with open("config.json") as f:
12
  config = json.load(f)
13
 
14
- pad_id = config["pad_id"]
15
- bos_id = config["bos_id"]
16
- eos_id = config["eos_id"]
17
 
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
 
@@ -29,12 +29,14 @@ sp_ar.load("sp_ar.model")
29
 
30
 
31
  # =========================
32
- # Model Classes (same as training)
33
  # =========================
 
34
  class MultiHeadAttention(nn.Module):
35
  def __init__(self, d_model, num_heads):
36
  super().__init__()
37
  assert d_model % num_heads == 0
 
38
  self.d_model = d_model
39
  self.num_heads = num_heads
40
  self.d_k = d_model // num_heads
@@ -44,35 +46,39 @@ class MultiHeadAttention(nn.Module):
44
  self.W_v = nn.Linear(d_model, d_model)
45
  self.W_o = nn.Linear(d_model, d_model)
46
 
47
- def split(self, x):
 
 
 
 
 
 
 
 
 
48
  B, T, D = x.size()
49
  return x.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
50
 
51
- def combine(self, x):
52
  B, H, T, D = x.size()
53
  return x.transpose(1, 2).contiguous().view(B, T, self.d_model)
54
 
55
- def attention(self, q, k, v, mask=None):
56
- scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
57
- if mask is not None:
58
- scores = scores.masked_fill(mask == 0, -1e9)
59
- return torch.softmax(scores, dim=-1) @ v
60
-
61
- def forward(self, q, k, v, mask=None):
62
- q = self.split(self.W_q(q))
63
- k = self.split(self.W_k(k))
64
- v = self.split(self.W_v(v))
65
 
66
- out = self.attention(q, k, v, mask)
67
- return self.W_o(self.combine(out))
68
 
69
 
70
- class FFN(nn.Module):
71
- def __init__(self, d_model, d_ff):
72
  super().__init__()
73
  self.net = nn.Sequential(
74
  nn.Linear(d_model, d_ff),
75
  nn.ReLU(),
 
76
  nn.Linear(d_ff, d_model)
77
  )
78
 
@@ -80,96 +86,117 @@ class FFN(nn.Module):
80
  return self.net(x)
81
 
82
 
83
- class PosEnc(nn.Module):
84
- def __init__(self, d_model, max_len):
85
  super().__init__()
 
 
 
86
  pe = torch.zeros(max_len, d_model)
87
- pos = torch.arange(0, max_len).unsqueeze(1)
88
- div = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
89
 
90
- pe[:, 0::2] = torch.sin(pos * div)
91
- pe[:, 1::2] = torch.cos(pos * div)
92
 
93
- self.pe = pe.unsqueeze(0)
 
 
 
94
 
95
  def forward(self, x):
96
- return x + self.pe[:, :x.size(1)].to(x.device)
 
97
 
98
 
99
  class EncoderLayer(nn.Module):
100
- def __init__(self, d_model, heads, d_ff):
101
  super().__init__()
102
- self.attn = MultiHeadAttention(d_model, heads)
103
- self.ffn = FFN(d_model, d_ff)
 
 
104
  self.norm1 = nn.LayerNorm(d_model)
105
  self.norm2 = nn.LayerNorm(d_model)
106
 
 
 
107
  def forward(self, x, mask):
108
- x = self.norm1(x + self.attn(x, x, x, mask))
109
- x = self.norm2(x + self.ffn(x))
110
  return x
111
 
112
 
113
  class DecoderLayer(nn.Module):
114
- def __init__(self, d_model, heads, d_ff):
115
  super().__init__()
116
- self.self_attn = MultiHeadAttention(d_model, heads)
117
- self.cross_attn = MultiHeadAttention(d_model, heads)
118
- self.ffn = FFN(d_model, d_ff)
119
-
120
- self.n1 = nn.LayerNorm(d_model)
121
- self.n2 = nn.LayerNorm(d_model)
122
- self.n3 = nn.LayerNorm(d_model)
123
-
124
- def forward(self, x, enc, src_mask, tgt_mask):
125
- x = self.n1(x + self.self_attn(x, x, x, tgt_mask))
126
- x = self.n2(x + self.cross_attn(x, enc, enc, src_mask))
127
- x = self.n3(x + self.ffn(x))
 
 
 
128
  return x
129
 
130
 
131
  class Transformer(nn.Module):
132
- def __init__(self):
 
 
 
133
  super().__init__()
134
 
135
- self.d_model = config["d_model"]
136
 
137
- self.enc_emb = nn.Embedding(config["src_vocab_size"], self.d_model, padding_idx=0)
138
- self.dec_emb = nn.Embedding(config["tgt_vocab_size"], self.d_model, padding_idx=0)
139
 
140
- self.pos = PosEnc(self.d_model, config["max_src_len"])
141
 
142
- self.enc_layers = nn.ModuleList([
143
- EncoderLayer(self.d_model, config["num_heads"], config["d_ff"])
144
- for _ in range(config["num_layers"])
145
  ])
146
 
147
- self.dec_layers = nn.ModuleList([
148
- DecoderLayer(self.d_model, config["num_heads"], config["d_ff"])
149
- for _ in range(config["num_layers"])
150
  ])
151
 
152
- self.fc = nn.Linear(self.d_model, config["tgt_vocab_size"])
153
 
154
- def masks(self, src, tgt):
155
  src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
156
- tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
157
- size = tgt.size(1)
158
- causal = torch.tril(torch.ones(size, size)).bool().to(tgt.device)
159
- return src_mask, tgt_mask & causal
 
 
 
 
 
160
 
161
  def forward(self, src, tgt):
162
- src_mask, tgt_mask = self.masks(src, tgt)
163
 
164
- src = self.pos(self.enc_emb(src))
165
- tgt = self.pos(self.dec_emb(tgt))
166
 
167
  enc = src
168
- for layer in self.enc_layers:
169
  enc = layer(enc, src_mask)
170
 
171
  dec = tgt
172
- for layer in self.dec_layers:
173
  dec = layer(dec, enc, src_mask, tgt_mask)
174
 
175
  return self.fc(dec)
@@ -178,22 +205,31 @@ class Transformer(nn.Module):
178
  # =========================
179
  # Load model
180
  # =========================
181
- model = Transformer().to(device)
 
 
 
 
 
 
 
 
 
182
  model.load_state_dict(torch.load("best_model.pt", map_location=device))
183
  model.eval()
184
 
185
 
186
  # =========================
187
- # Inference
188
  # =========================
189
- def translate(sentence):
190
 
191
- tokens = sp_en.encode(sentence)
192
- tokens = [bos_id] + tokens + [eos_id]
193
 
194
- src = torch.tensor(tokens).unsqueeze(0).to(device)
195
 
196
- out = [bos_id]
197
 
198
  for _ in range(50):
199
 
@@ -203,25 +239,21 @@ def translate(sentence):
203
  pred = model(src, tgt)
204
 
205
  next_token = pred[0, -1].argmax().item()
206
-
207
  out.append(next_token)
208
 
209
- if next_token == eos_id:
210
  break
211
 
212
- result = sp_ar.decode([t for t in out if t not in [bos_id, eos_id, pad_id]])
213
  return result
214
 
215
 
216
  # =========================
217
  # UI
218
  # =========================
219
- demo = gr.Interface(
220
  fn=translate,
221
  inputs="text",
222
  outputs="text",
223
- title="ArabicEnglish Translator (Transformer)",
224
- description="Enter English sentence and get Arabic translation"
225
- )
226
-
227
- demo.launch()
 
11
  with open("config.json") as f:
12
  config = json.load(f)
13
 
14
+ padIndex = config["pad_id"]
15
+ BOSIndex = config["bos_id"]
16
+ EOSIndex = config["eos_id"]
17
 
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
 
 
29
 
30
 
31
  # =========================
32
+ # MODEL (EXACT TRAINING VERSION)
33
  # =========================
34
+
35
  class MultiHeadAttention(nn.Module):
36
  def __init__(self, d_model, num_heads):
37
  super().__init__()
38
  assert d_model % num_heads == 0
39
+
40
  self.d_model = d_model
41
  self.num_heads = num_heads
42
  self.d_k = d_model // num_heads
 
46
  self.W_v = nn.Linear(d_model, d_model)
47
  self.W_o = nn.Linear(d_model, d_model)
48
 
49
+ def scaled_dot_product_attention(self, Q, K, V, mask=None):
50
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
51
+
52
+ if mask is not None:
53
+ scores = scores.masked_fill(mask == 0, -1e9)
54
+
55
+ attn = torch.softmax(scores, dim=-1)
56
+ return torch.matmul(attn, V)
57
+
58
+ def split_heads(self, x):
59
  B, T, D = x.size()
60
  return x.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
61
 
62
+ def combine_heads(self, x):
63
  B, H, T, D = x.size()
64
  return x.transpose(1, 2).contiguous().view(B, T, self.d_model)
65
 
66
+ def forward(self, Q, K, V, mask=None):
67
+ Q = self.split_heads(self.W_q(Q))
68
+ K = self.split_heads(self.W_k(K))
69
+ V = self.split_heads(self.W_v(V))
 
 
 
 
 
 
70
 
71
+ out = self.scaled_dot_product_attention(Q, K, V, mask)
72
+ return self.W_o(self.combine_heads(out))
73
 
74
 
75
+ class PositionWiseFeedForward(nn.Module):
76
+ def __init__(self, d_model, d_ff, dropout=0.1):
77
  super().__init__()
78
  self.net = nn.Sequential(
79
  nn.Linear(d_model, d_ff),
80
  nn.ReLU(),
81
+ nn.Dropout(dropout),
82
  nn.Linear(d_ff, d_model)
83
  )
84
 
 
86
  return self.net(x)
87
 
88
 
89
+ class PositionalEncoding(nn.Module):
90
+ def __init__(self, d_model, max_len, dropout=0.1):
91
  super().__init__()
92
+
93
+ self.dropout = nn.Dropout(dropout)
94
+
95
  pe = torch.zeros(max_len, d_model)
96
+ position = torch.arange(0, max_len).unsqueeze(1)
 
97
 
98
+ div_term = torch.exp(torch.arange(0, d_model, 2) *
99
+ -(math.log(10000.0) / d_model))
100
 
101
+ pe[:, 0::2] = torch.sin(position * div_term)
102
+ pe[:, 1::2] = torch.cos(position * div_term)
103
+
104
+ self.register_buffer("pe", pe.unsqueeze(0))
105
 
106
  def forward(self, x):
107
+ x = x + self.pe[:, :x.size(1)]
108
+ return self.dropout(x)
109
 
110
 
111
  class EncoderLayer(nn.Module):
112
+ def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
113
  super().__init__()
114
+
115
+ self.self_attn = MultiHeadAttention(d_model, num_heads)
116
+ self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
117
+
118
  self.norm1 = nn.LayerNorm(d_model)
119
  self.norm2 = nn.LayerNorm(d_model)
120
 
121
+ self.dropout = nn.Dropout(dropout)
122
+
123
  def forward(self, x, mask):
124
+ x = self.norm1(x + self.dropout(self.self_attn(x, x, x, mask)))
125
+ x = self.norm2(x + self.dropout(self.feed_forward(x)))
126
  return x
127
 
128
 
129
  class DecoderLayer(nn.Module):
130
+ def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
131
  super().__init__()
132
+
133
+ self.self_attn = MultiHeadAttention(d_model, num_heads)
134
+ self.cross_attn = MultiHeadAttention(d_model, num_heads)
135
+ self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
136
+
137
+ self.norm1 = nn.LayerNorm(d_model)
138
+ self.norm2 = nn.LayerNorm(d_model)
139
+ self.norm3 = nn.LayerNorm(d_model)
140
+
141
+ self.dropout = nn.Dropout(dropout)
142
+
143
+ def forward(self, x, enc_out, src_mask, tgt_mask):
144
+ x = self.norm1(x + self.dropout(self.self_attn(x, x, x, tgt_mask)))
145
+ x = self.norm2(x + self.dropout(self.cross_attn(x, enc_out, enc_out, src_mask)))
146
+ x = self.norm3(x + self.dropout(self.feed_forward(x)))
147
  return x
148
 
149
 
150
  class Transformer(nn.Module):
151
+ def __init__(self, src_vocab, tgt_vocab,
152
+ d_model=256, num_heads=4, num_layers=3,
153
+ d_ff=512, max_len=100):
154
+
155
  super().__init__()
156
 
157
+ self.d_model = d_model
158
 
159
+ self.encoder_embedding = nn.Embedding(src_vocab, d_model, padding_idx=0)
160
+ self.decoder_embedding = nn.Embedding(tgt_vocab, d_model, padding_idx=0)
161
 
162
+ self.positional_encoding = PositionalEncoding(d_model, max_len)
163
 
164
+ self.encoder_layers = nn.ModuleList([
165
+ EncoderLayer(d_model, num_heads, d_ff)
166
+ for _ in range(num_layers)
167
  ])
168
 
169
+ self.decoder_layers = nn.ModuleList([
170
+ DecoderLayer(d_model, num_heads, d_ff)
171
+ for _ in range(num_layers)
172
  ])
173
 
174
+ self.fc = nn.Linear(d_model, tgt_vocab)
175
 
176
+ def generate_mask(self, src, tgt):
177
  src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
178
+
179
+ tgt_pad = (tgt != 0).unsqueeze(1).unsqueeze(3)
180
+ T = tgt.size(1)
181
+
182
+ causal = torch.tril(torch.ones(T, T)).bool().to(tgt.device)
183
+
184
+ tgt_mask = tgt_pad & causal
185
+
186
+ return src_mask, tgt_mask
187
 
188
  def forward(self, src, tgt):
189
+ src_mask, tgt_mask = self.generate_mask(src, tgt)
190
 
191
+ src = self.positional_encoding(self.encoder_embedding(src) * math.sqrt(self.d_model))
192
+ tgt = self.positional_encoding(self.decoder_embedding(tgt) * math.sqrt(self.d_model))
193
 
194
  enc = src
195
+ for layer in self.encoder_layers:
196
  enc = layer(enc, src_mask)
197
 
198
  dec = tgt
199
+ for layer in self.decoder_layers:
200
  dec = layer(dec, enc, src_mask, tgt_mask)
201
 
202
  return self.fc(dec)
 
205
  # =========================
206
  # Load model
207
  # =========================
208
+ model = Transformer(
209
+ config["src_vocab_size"],
210
+ config["tgt_vocab_size"],
211
+ config["d_model"],
212
+ config["num_heads"],
213
+ config["num_layers"],
214
+ config["d_ff"],
215
+ max_len=max(config["max_src_len"], config["max_tgt_len"])
216
+ ).to(device)
217
+
218
  model.load_state_dict(torch.load("best_model.pt", map_location=device))
219
  model.eval()
220
 
221
 
222
  # =========================
223
+ # Translation
224
  # =========================
225
+ def translate(text):
226
 
227
+ src = sp_en.encode(text)
228
+ src = [BOSIndex] + src + [EOSIndex]
229
 
230
+ src = torch.tensor(src).unsqueeze(0).to(device)
231
 
232
+ out = [BOSIndex]
233
 
234
  for _ in range(50):
235
 
 
239
  pred = model(src, tgt)
240
 
241
  next_token = pred[0, -1].argmax().item()
 
242
  out.append(next_token)
243
 
244
+ if next_token == EOSIndex:
245
  break
246
 
247
+ result = sp_ar.decode([t for t in out if t not in [BOSIndex, EOSIndex, padIndex]])
248
  return result
249
 
250
 
251
  # =========================
252
  # UI
253
  # =========================
254
+ gr.Interface(
255
  fn=translate,
256
  inputs="text",
257
  outputs="text",
258
+ title="EnglishArabic Transformer",
259
+ ).launch()