Commit ·
6543d58
1
Parent(s): ccfa333
Updated trainer
Browse files- src/pipes/const.py +2 -0
- src/pipes/data.py +32 -0
- src/pipes/models.py +0 -32
src/pipes/const.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
| 1 |
data_dir: str = "E:/bn_multi_tribe_mt/data/"
|
| 2 |
langs: list[str] = ['bn', 'en', 'gr']
|
| 3 |
MAX_SEQ_LEN = 30
|
|
|
|
|
|
|
|
|
# Root directory that holds the raw and processed dataset files.
data_dir: str = "E:/bn_multi_tribe_mt/data/"

# Language codes handled by the translation pipeline.
langs: list[str] = ['bn', 'en', 'gr']

# Maximum number of tokens kept per sequence when padding/truncating.
MAX_SEQ_LEN = 30

# tf.data pipeline settings: samples per batch, and shuffle-buffer size.
BATCH_SIZE = 64

BUFFER_SIZE = 10000
|
src/pipes/data.py
CHANGED
|
@@ -3,6 +3,7 @@ import const
|
|
| 3 |
import utils
|
| 4 |
import string
|
| 5 |
|
|
|
|
| 6 |
class SequenceLoader:
|
| 7 |
def __init__(self):
|
| 8 |
self.sequence_dict = None
|
|
@@ -38,6 +39,12 @@ class SequenceLoader:
|
|
| 38 |
self.lang = lang
|
| 39 |
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
def remove_punctuation_from_seq(seq):
|
| 42 |
english_punctuations = string.punctuation
|
| 43 |
bangla_punctuations = "৷-–—’‘৳…।"
|
|
@@ -157,6 +164,29 @@ class Dataset:
|
|
| 157 |
seq_processor.pad()
|
| 158 |
self.dataset_dict = seq_processor.get_dict()
|
| 159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
def get_dict(self):
|
| 161 |
return self.dataset_dict
|
| 162 |
|
|
@@ -167,4 +197,6 @@ if __name__ == "__main__":
|
|
| 167 |
dataset_dict = dataset_object.get_dict()
|
| 168 |
utils.save_dict("{}/dataset.txt".format(const.data_dir), dataset_dict)
|
| 169 |
dataset_object.process()
|
|
|
|
|
|
|
| 170 |
print(utils.load_dict("{}/dataset.txt".format(const.data_dir)))
|
|
|
|
| 3 |
import utils
|
| 4 |
import string
|
| 5 |
|
| 6 |
+
|
| 7 |
class SequenceLoader:
|
| 8 |
def __init__(self):
|
| 9 |
self.sequence_dict = None
|
|
|
|
| 39 |
self.lang = lang
|
| 40 |
|
| 41 |
|
| 42 |
+
def serialize(src_seq, tar_seq):
    """Split a target batch into teacher-forcing input/output pairs.

    The decoder input drops the last token of every target sequence and the
    expected output drops the first, so the model is trained to predict the
    next token. Both slices are densified with .to_tensor()
    (ragged -> padded dense).
    """
    decoder_input = tar_seq[:, :-1].to_tensor()
    decoder_target = tar_seq[:, 1:].to_tensor()
    return (src_seq, decoder_input), decoder_target
|
| 46 |
+
|
| 47 |
+
|
| 48 |
def remove_punctuation_from_seq(seq):
|
| 49 |
english_punctuations = string.punctuation
|
| 50 |
bangla_punctuations = "৷-–—’‘৳…।"
|
|
|
|
| 164 |
seq_processor.pad()
|
| 165 |
self.dataset_dict = seq_processor.get_dict()
|
| 166 |
|
| 167 |
+
def pull(self):
|
| 168 |
+
src_lang_train_seqs = self.dataset_dict[self.langs[0]]["train"]
|
| 169 |
+
tar_lang_train_seqs = self.dataset_dict[self.langs[1]]["train"]
|
| 170 |
+
|
| 171 |
+
src_lang_val_seqs = self.dataset_dict[self.langs[0]]["val"]
|
| 172 |
+
tar_lang_val_seqs = self.dataset_dict[self.langs[1]]["val"]
|
| 173 |
+
|
| 174 |
+
train_ds = ((tf.data.Dataset
|
| 175 |
+
.from_tensor_slices((src_lang_train_seqs, tar_lang_train_seqs)))
|
| 176 |
+
.shuffle(const.BUFFER_SIZE)
|
| 177 |
+
.batch(const.BATCH_SIZE))
|
| 178 |
+
|
| 179 |
+
val_ds = (tf.data.Dataset
|
| 180 |
+
.from_tensor_slices(src_lang_val_seqs, tar_lang_val_seqs)
|
| 181 |
+
.shuffle(const.BUFFER_SIZE)
|
| 182 |
+
.batch(const.BATCH_SIZE))
|
| 183 |
+
|
| 184 |
+
train_ds = train_ds.map(serialize, tf.data.AUTOTUNE)
|
| 185 |
+
val_ds = val_ds.map(serialize, tf.data.AUTOTUNE)
|
| 186 |
+
|
| 187 |
+
return trainset, valset
|
| 188 |
+
|
| 189 |
+
@staticmethod
|
| 190 |
def get_dict(self):
|
| 191 |
return self.dataset_dict
|
| 192 |
|
|
|
|
| 197 |
dataset_dict = dataset_object.get_dict()
|
| 198 |
utils.save_dict("{}/dataset.txt".format(const.data_dir), dataset_dict)
|
| 199 |
dataset_object.process()
|
| 200 |
+
trainset, valset = dataset_object.pull()
|
| 201 |
+
|
| 202 |
print(utils.load_dict("{}/dataset.txt".format(const.data_dir)))
|
src/pipes/models.py
CHANGED
|
@@ -34,38 +34,6 @@ class Seq2Seq:
|
|
| 34 |
outputs = self.output_layer(decoder_outputs)
|
| 35 |
self.model = tf.keras.Model([encoder_inputs, decoder_inputs], outputs)
|
| 36 |
|
| 37 |
-
def run(self, encoder_input_data, decoder_input_data, val_encoder_input_data, val_decoder_input_data):
|
| 38 |
-
self.model.compile(
|
| 39 |
-
optimizer=self.optimizer,
|
| 40 |
-
loss=self.loss,
|
| 41 |
-
metrics=self.metrics
|
| 42 |
-
)
|
| 43 |
-
|
| 44 |
-
decoder_target_data = [[sentence[1:] + [0]] for sentence in decoder_input_data]
|
| 45 |
-
val_decoder_target_data = [[sentence[1:] + [0]] for sentence in val_decoder_input_data]
|
| 46 |
-
|
| 47 |
-
self.model.fit(
|
| 48 |
-
([encoder_input_data, decoder_input_data]),
|
| 49 |
-
decoder_target_data,
|
| 50 |
-
batch_size=self.batch_size,
|
| 51 |
-
epochs=self.epochs,
|
| 52 |
-
validation_data=([val_encoder_input_data, val_decoder_input_data], val_decoder_target_data)
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
def get(self):
|
| 56 |
return self.model
|
| 57 |
|
| 58 |
-
def set_epochs(self, epochs):
|
| 59 |
-
self.epochs = epochs
|
| 60 |
-
|
| 61 |
-
def set_batch_size(self, batch_size):
|
| 62 |
-
self.batch_size = batch_size
|
| 63 |
-
|
| 64 |
-
def set_loss(self, loss):
|
| 65 |
-
self.loss = loss
|
| 66 |
-
|
| 67 |
-
def set_optimizer(self, optimizer):
|
| 68 |
-
self.optimizer = optimizer
|
| 69 |
-
|
| 70 |
-
def set_metric(self, metrics):
|
| 71 |
-
self.metrics = metrics
|
|
|
|
| 34 |
outputs = self.output_layer(decoder_outputs)
|
| 35 |
self.model = tf.keras.Model([encoder_inputs, decoder_inputs], outputs)
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
def get(self):
|
| 38 |
return self.model
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|