jannatulferdaws committed on
Commit
3cff8f8
·
verified ·
1 Parent(s): b19d9f0

train the model

Browse files
Files changed (3) hide show
  1. app.py +17 -0
  2. requirements.txt +3 -0
  3. spa_to_eng.py +413 -0
app.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from transformers import pipeline
4
+
5
+ pipe = pipeline("translation", model="my_model.keras")
6
+
7
+ def predict(text):
8
+ return pipe(text)[0]["translation_text"]
9
+
10
+ demo = gr.Interface(
11
+ fn=predict,
12
+ inputs='text',
13
+ outputs='text',
14
+ )
15
+
16
+ demo.launch()
17
+
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ tensorflow
+ keras
+ numpy
+ gradio
+ transformers
spa_to_eng.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[2]:
5
+
6
+
7
+ # We set the backend to TensorFlow. The code works with
8
+ # both `tensorflow` and `torch`. It does not work with JAX
9
+ # due to the behavior of `jax.numpy.tile` in a jit scope
10
+ # (used in `TransformerDecoder.get_causal_attention_mask()`:
11
+ # `tile` in JAX does not support a dynamic `reps` argument.
12
+ # You can make the code work in JAX by wrapping the
13
+ # inside of the `get_causal_attention_mask` method in
14
+ # a decorator to prevent jit compilation:
15
+ # `with jax.ensure_compile_time_eval():`.
16
+ import os
17
+
18
+ os.environ["KERAS_BACKEND"] = "tensorflow"
19
+
20
+ import pathlib
21
+ import random
22
+ import string
23
+ import re
24
+ import numpy as np
25
+
26
+ import tensorflow.data as tf_data
27
+ import tensorflow.strings as tf_strings
28
+
29
+ import keras
30
+ from keras import layers
31
+ from keras import ops
32
+ from keras.layers import TextVectorization
33
+
34
+
35
+ # In[3]:
36
+
37
+
38
+ print(keras.__version__)
39
+
40
+
41
+ # In[4]:
42
+
43
+
44
+ # text_file = keras.utils.get_file(
45
+ # fname="spa-eng.zip",
46
+ # origin="http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip",
47
+ # extract=True,
48
+ # )
49
+ text_file = 'data/spa-eng/spa_new.txt'
50
+
51
+
52
+ # In[5]:
53
+
54
+
55
+ with open(text_file, encoding='utf-8') as f:
56
+ lines = f.read().split("\n")[:-1]
57
+ text_pairs = []
58
+ for line in lines:
59
+ eng, spa = line.split("\t")
60
+ spa = "[start] " + spa + " [end]"
61
+ text_pairs.append((eng, spa))
62
+
63
+
64
+ #
65
+
66
+ # In[6]:
67
+
68
+
69
+ for _ in range(5):
70
+ print(random.choice(text_pairs))
71
+
72
+
73
+ # In[7]:
74
+
75
+
76
+ random.shuffle(text_pairs)
77
+ num_val_samples = int(0.15 * len(text_pairs))
78
+ num_train_samples = len(text_pairs) - 2 * num_val_samples
79
+ train_pairs = text_pairs[:num_train_samples]
80
+ val_pairs = text_pairs[num_train_samples : num_train_samples + num_val_samples]
81
+ test_pairs = text_pairs[num_train_samples + num_val_samples :]
82
+
83
+ print(f"{len(text_pairs)} total pairs")
84
+ print(f"{len(train_pairs)} training pairs")
85
+ print(f"{len(val_pairs)} validation pairs")
86
+ print(f"{len(test_pairs)} test pairs")
87
+
88
+
89
+ # In[8]:
90
+
91
+
92
+ strip_chars = string.punctuation + "¿"
93
+ strip_chars = strip_chars.replace("[", "")
94
+ strip_chars = strip_chars.replace("]", "")
95
+
96
+ vocab_size = 15000
97
+ sequence_length = 20
98
+ batch_size = 64
99
+
100
+
101
+ def custom_standardization(input_string):
102
+ lowercase = tf_strings.lower(input_string)
103
+ return tf_strings.regex_replace(lowercase, "[%s]" % re.escape(strip_chars), "")
104
+
105
+
106
+ eng_vectorization = TextVectorization(
107
+ max_tokens=vocab_size,
108
+ output_mode="int",
109
+ output_sequence_length=sequence_length,
110
+ )
111
+ spa_vectorization = TextVectorization(
112
+ max_tokens=vocab_size,
113
+ output_mode="int",
114
+ output_sequence_length=sequence_length + 1,
115
+ standardize=custom_standardization,
116
+ )
117
+ train_eng_texts = [pair[0] for pair in train_pairs]
118
+ train_spa_texts = [pair[1] for pair in train_pairs]
119
+ eng_vectorization.adapt(train_eng_texts)
120
+ spa_vectorization.adapt(train_spa_texts)
121
+
122
+
123
+ # In[9]:
124
+
125
+
126
+ def format_dataset(eng, spa):
127
+ eng = eng_vectorization(eng)
128
+ spa = spa_vectorization(spa)
129
+ return (
130
+ {
131
+ "encoder_inputs": eng,
132
+ "decoder_inputs": spa[:, :-1],
133
+ },
134
+ spa[:, 1:],
135
+ )
136
+
137
+
138
+ def make_dataset(pairs):
139
+ eng_texts, spa_texts = zip(*pairs)
140
+ eng_texts = list(eng_texts)
141
+ spa_texts = list(spa_texts)
142
+ dataset = tf_data.Dataset.from_tensor_slices((eng_texts, spa_texts))
143
+ dataset = dataset.batch(batch_size)
144
+ dataset = dataset.map(format_dataset)
145
+ return dataset.cache().shuffle(2048).prefetch(16)
146
+
147
+
148
+ train_ds = make_dataset(train_pairs)
149
+ val_ds = make_dataset(val_pairs)
150
+
151
+
152
+ # In[10]:
153
+
154
+
155
+ for inputs, targets in train_ds.take(1):
156
+ print(f'inputs["encoder_inputs"].shape: {inputs["encoder_inputs"].shape}')
157
+ print(f'inputs["decoder_inputs"].shape: {inputs["decoder_inputs"].shape}')
158
+ print(f"targets.shape: {targets.shape}")
159
+
160
+
161
+ # In[12]:
162
+
163
+
164
+ print(keras.__version__)
165
+
166
+
167
+ # In[11]:
168
+
169
+
170
+ import keras.ops as ops
171
+
172
+
173
+ class TransformerEncoder(layers.Layer):
174
+ def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
175
+ super().__init__(**kwargs)
176
+ self.embed_dim = embed_dim
177
+ self.dense_dim = dense_dim
178
+ self.num_heads = num_heads
179
+ self.attention = layers.MultiHeadAttention(
180
+ num_heads=num_heads, key_dim=embed_dim
181
+ )
182
+ self.dense_proj = keras.Sequential(
183
+ [
184
+ layers.Dense(dense_dim, activation="relu"),
185
+ layers.Dense(embed_dim),
186
+ ]
187
+ )
188
+ self.layernorm_1 = layers.LayerNormalization()
189
+ self.layernorm_2 = layers.LayerNormalization()
190
+ self.supports_masking = True
191
+
192
+ def call(self, inputs, mask=None):
193
+ if mask is not None:
194
+ padding_mask = ops.cast(mask[:, None, :], dtype="int32")
195
+ else:
196
+ padding_mask = None
197
+
198
+ attention_output = self.attention(
199
+ query=inputs, value=inputs, key=inputs, attention_mask=padding_mask
200
+ )
201
+ proj_input = self.layernorm_1(inputs + attention_output)
202
+ proj_output = self.dense_proj(proj_input)
203
+ return self.layernorm_2(proj_input + proj_output)
204
+
205
+ def get_config(self):
206
+ config = super().get_config()
207
+ config.update(
208
+ {
209
+ "embed_dim": self.embed_dim,
210
+ "dense_dim": self.dense_dim,
211
+ "num_heads": self.num_heads,
212
+ }
213
+ )
214
+ return config
215
+
216
+
217
+ class PositionalEmbedding(layers.Layer):
218
+ def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
219
+ super().__init__(**kwargs)
220
+ self.token_embeddings = layers.Embedding(
221
+ input_dim=vocab_size, output_dim=embed_dim
222
+ )
223
+ self.position_embeddings = layers.Embedding(
224
+ input_dim=sequence_length, output_dim=embed_dim
225
+ )
226
+ self.sequence_length = sequence_length
227
+ self.vocab_size = vocab_size
228
+ self.embed_dim = embed_dim
229
+
230
+ def call(self, inputs):
231
+ length = ops.shape(inputs)[-1]
232
+ positions = ops.arange(0, length, 1)
233
+ embedded_tokens = self.token_embeddings(inputs)
234
+ embedded_positions = self.position_embeddings(positions)
235
+ return embedded_tokens + embedded_positions
236
+
237
+ def compute_mask(self, inputs, mask=None):
238
+ if mask is None:
239
+ return None
240
+ else:
241
+ return ops.not_equal(inputs, 0)
242
+
243
+ def get_config(self):
244
+ config = super().get_config()
245
+ config.update(
246
+ {
247
+ "sequence_length": self.sequence_length,
248
+ "vocab_size": self.vocab_size,
249
+ "embed_dim": self.embed_dim,
250
+ }
251
+ )
252
+ return config
253
+
254
+
255
+ class TransformerDecoder(layers.Layer):
256
+ def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
257
+ super().__init__(**kwargs)
258
+ self.embed_dim = embed_dim
259
+ self.latent_dim = latent_dim
260
+ self.num_heads = num_heads
261
+ self.attention_1 = layers.MultiHeadAttention(
262
+ num_heads=num_heads, key_dim=embed_dim
263
+ )
264
+ self.attention_2 = layers.MultiHeadAttention(
265
+ num_heads=num_heads, key_dim=embed_dim
266
+ )
267
+ self.dense_proj = keras.Sequential(
268
+ [
269
+ layers.Dense(latent_dim, activation="relu"),
270
+ layers.Dense(embed_dim),
271
+ ]
272
+ )
273
+ self.layernorm_1 = layers.LayerNormalization()
274
+ self.layernorm_2 = layers.LayerNormalization()
275
+ self.layernorm_3 = layers.LayerNormalization()
276
+ self.supports_masking = True
277
+
278
+ def call(self, inputs, encoder_outputs, mask=None):
279
+ causal_mask = self.get_causal_attention_mask(inputs)
280
+ if mask is not None:
281
+ padding_mask = ops.cast(mask[:, None, :], dtype="int32")
282
+ padding_mask = ops.minimum(padding_mask, causal_mask)
283
+ else:
284
+ padding_mask = None
285
+
286
+ attention_output_1 = self.attention_1(
287
+ query=inputs, value=inputs, key=inputs, attention_mask=causal_mask
288
+ )
289
+ out_1 = self.layernorm_1(inputs + attention_output_1)
290
+
291
+ attention_output_2 = self.attention_2(
292
+ query=out_1,
293
+ value=encoder_outputs,
294
+ key=encoder_outputs,
295
+ attention_mask=padding_mask,
296
+ )
297
+ out_2 = self.layernorm_2(out_1 + attention_output_2)
298
+
299
+ proj_output = self.dense_proj(out_2)
300
+ return self.layernorm_3(out_2 + proj_output)
301
+
302
+ def get_causal_attention_mask(self, inputs):
303
+ input_shape = ops.shape(inputs)
304
+ batch_size, sequence_length = input_shape[0], input_shape[1]
305
+ i = ops.arange(sequence_length)[:, None]
306
+ j = ops.arange(sequence_length)
307
+ mask = ops.cast(i >= j, dtype="int32")
308
+ mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))
309
+ mult = ops.concatenate(
310
+ [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])],
311
+ axis=0,
312
+ )
313
+ return ops.tile(mask, mult)
314
+
315
+ def get_config(self):
316
+ config = super().get_config()
317
+ config.update(
318
+ {
319
+ "embed_dim": self.embed_dim,
320
+ "latent_dim": self.latent_dim,
321
+ "num_heads": self.num_heads,
322
+ }
323
+ )
324
+ return config
325
+
326
+
327
+ # In[12]:
328
+
329
+
330
+ embed_dim = 256
331
+ latent_dim = 2048
332
+ num_heads = 8
333
+
334
+ encoder_inputs = keras.Input(shape=(None,), dtype="int64", name="encoder_inputs")
335
+ x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(encoder_inputs)
336
+ encoder_outputs = TransformerEncoder(embed_dim, latent_dim, num_heads)(x)
337
+ encoder = keras.Model(encoder_inputs, encoder_outputs)
338
+
339
+ decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")
340
+ encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name="decoder_state_inputs")
341
+ x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(decoder_inputs)
342
+ x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, encoded_seq_inputs)
343
+ x = layers.Dropout(0.5)(x)
344
+ decoder_outputs = layers.Dense(vocab_size, activation="softmax")(x)
345
+ decoder = keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs)
346
+
347
+ decoder_outputs = decoder([decoder_inputs, encoder_outputs])
348
+ transformer = keras.Model(
349
+ [encoder_inputs, decoder_inputs], decoder_outputs, name="transformer"
350
+ )
351
+
352
+
353
+ # In[15]:
354
+
355
+
356
+ epochs = 1 # This should be at least 30 for convergence
357
+
358
+ transformer.summary()
359
+ transformer.compile(
360
+ "rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
361
+ )
362
+ transformer.fit(train_ds, epochs=epochs, validation_data=val_ds)
363
+
364
+
365
+ # In[ ]:
366
+
367
+
368
+ spa_vocab = spa_vectorization.get_vocabulary()
369
+ spa_index_lookup = dict(zip(range(len(spa_vocab)), spa_vocab))
370
+ max_decoded_sentence_length = 20
371
+
372
+
373
+ def decode_sequence(input_sentence):
374
+ tokenized_input_sentence = eng_vectorization([input_sentence])
375
+ decoded_sentence = "[start]"
376
+ for i in range(max_decoded_sentence_length):
377
+ tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]
378
+ predictions = transformer([tokenized_input_sentence, tokenized_target_sentence])
379
+
380
+ # ops.argmax(predictions[0, i, :]) is not a concrete value for jax here
381
+ sampled_token_index = ops.convert_to_numpy(
382
+ ops.argmax(predictions[0, i, :])
383
+ ).item(0)
384
+ sampled_token = spa_index_lookup[sampled_token_index]
385
+ decoded_sentence += " " + sampled_token
386
+
387
+ if sampled_token == "[end]":
388
+ break
389
+ return decoded_sentence
390
+
391
+
392
+ test_eng_texts = [pair[0] for pair in test_pairs]
393
+ for _ in range(30):
394
+ input_sentence = random.choice(test_eng_texts)
395
+ translated = decode_sequence(input_sentence)
396
+ print(f'English: {input_sentence}')
397
+ print(f'Spanish: {translated}')
398
+
399
+
400
+
401
+ # In[19]:
402
+
403
+
404
+ from keras.utils import plot_model
405
+
406
+ plot_model(transformer, to_file='models/model_trn_plot.png', show_shapes=True, show_layer_names=True)
407
+
408
+
409
+ # In[21]:
410
+
411
+
412
+ transformer.save('my_model.keras')
413
+