MasumBhuiyan's picture
Seq2Seq model implemented
c5dc1d4
import random
from pipes import const
from pipes import utils
import string
import tensorflow as tf
import numpy as np
class SequenceLoader:
    """Load one language's raw sentence file and split it into train/val.

    The shuffle permutation is drawn the first time ``pack`` runs and is
    reused on every later call, so when the same instance packs several
    languages, line ``i`` of each file lands at the same shuffled position.
    """

    def __init__(self):
        self.lang = None
        self.vocab = None            # never assigned by this class
        self.max_seq_length = None   # never assigned by this class
        self.sequences = None
        self.shuffled_indices = None
        self.shuffled_sequences = None
        self.sequence_dict = None

    def pack(self):
        """Read the current language's file and build its train/val dict.

        The first 80% of the shuffled sequences become ``"train"``, the rest
        ``"val"``; ``"count"`` holds the total number of sequences read.
        """
        raw_path = "{}/raw/{}.txt".format(const.data_dir, self.lang)
        # presumably a list with one sentence per line -- confirm in utils.read_file
        self.sequences = utils.read_file(raw_path)
        total = len(self.sequences)

        if self.shuffled_indices is None:
            # Drawn once only: later languages reuse the same permutation.
            order = list(range(total))
            random.shuffle(order)
            self.shuffled_indices = order

        self.shuffled_sequences = [self.sequences[idx] for idx in self.shuffled_indices]

        cut = int(total * 0.80)
        self.sequence_dict = {
            "train": self.shuffled_sequences[:cut],
            "val": self.shuffled_sequences[cut:],
            "count": total,
        }

    def get_dict(self):
        """Return the dict built by the most recent ``pack`` (``None`` before that)."""
        return self.sequence_dict

    def set_lang(self, lang):
        """Select the language whose file the next ``pack`` call reads."""
        self.lang = lang
def serialize(src_seq, tar_seq):
    """Turn a (source, target) pair into ``((source, target_in), target_out)``.

    ``target_in`` is the target with its last column dropped and
    ``target_out`` is the target with its first column dropped, so the two
    are the same sequences offset by one position. The ``[:, ...]`` indexing
    requires ``tar_seq`` to have at least two dimensions.
    """
    source = tf.convert_to_tensor(src_seq)
    shifted_in = tf.convert_to_tensor(tar_seq[:, :-1])
    shifted_out = tf.convert_to_tensor(tar_seq[:, 1:])
    return (source, shifted_in), shifted_out
def remove_punctuation_from_seq(seq):
    """Strip ASCII and Bangla punctuation from ``seq`` and normalise spacing.

    Every character found in ``string.punctuation`` or in the Bangla mark
    list is deleted; the remaining words are then re-joined with single
    spaces, which also removes leading and trailing whitespace.
    """
    bangla_marks = "৷-–—’‘৳…।"
    # A translation table that maps each punctuation character to "delete".
    drop_table = str.maketrans("", "", string.punctuation + bangla_marks)
    without_marks = seq.translate(drop_table)
    return " ".join(without_marks.strip().split())
def add_start_end_tags_seq(sequence):
    """Return ``sequence`` wrapped as ``<SOS> sequence <EOS>``."""
    pieces = ("<SOS> ", sequence, " <EOS>")
    return "".join(pieces)
def pad_sequence(sequence, max_seq_len, padding_token=0):
    """Return ``sequence`` as a list of exactly ``max_seq_len`` items.

    Longer inputs are cut off at ``max_seq_len``; shorter ones are extended
    on the right with ``padding_token``.
    """
    clipped = sequence[:max_seq_len]
    # A non-positive shortfall multiplies out to an empty list, i.e. no padding.
    shortfall = max_seq_len - len(sequence)
    filler = [padding_token] * shortfall
    return clipped + filler
class SequenceProcessor:
    """Apply text preprocessing steps to one language of a dataset dict.

    The dict is indexed as ``dataset_dict[lang]["train"]`` and
    ``dataset_dict[lang]["val"]``. Each step replaces the entries of those two
    lists in place, so the list objects themselves are kept. ``set_lang``
    selects which language the following steps operate on.

    NOTE(review): ``tokenize`` numbers words from 0 (position in the sorted
    vocabulary) while ``pad`` fills with ``pad_sequence``'s default token 0,
    so the first vocabulary word and padding share an id. Changing that would
    shift every token id and ``vocab_size``, so it is only flagged here.
    """

    _SPLITS = ("train", "val")

    def __init__(self, _dataset_dict):
        self.max_seq_len = 0
        self.lang = None
        self.dataset_dict = _dataset_dict
        self.vocab = None

    def _split_lists(self):
        """Yield the train and val sequence lists of the current language."""
        lang_entry = self.dataset_dict[self.lang]
        for split in self._SPLITS:
            yield lang_entry[split]

    def _rewrite_each(self, transform):
        """Replace every train/val sequence with ``transform(sequence)`` in place."""
        for seqs in self._split_lists():
            # Slice assignment keeps the same list object alive for any other holder.
            seqs[:] = [transform(seq) for seq in seqs]

    def remove_punctuation(self):
        """Strip punctuation from every train and val sentence."""
        self._rewrite_each(remove_punctuation_from_seq)

    def build_vocab(self):
        """Collect the sorted set of whitespace-separated words in train and val.

        Stores the list under ``"vocab"`` and its length under ``"vocab_size"``
        in the current language's entry, and keeps it on ``self.vocab``.
        """
        words = set()
        for seqs in self._split_lists():
            for seq in seqs:
                words.update(seq.split())
        self.vocab = sorted(words)
        lang_entry = self.dataset_dict[self.lang]
        lang_entry["vocab"] = self.vocab
        lang_entry["vocab_size"] = len(self.vocab)

    def add_start_end_tags(self):
        """Wrap every sentence in start/end tags and record the longest length.

        The length is counted in whitespace-separated words, tags included,
        and stored under ``"max_seq_len"`` for the current language.
        """
        self._rewrite_each(add_start_end_tags_seq)
        longest = max(
            (len(seq.split()) for seqs in self._split_lists() for seq in seqs),
            default=0,
        )
        self.max_seq_len = max(self.max_seq_len, longest)
        self.dataset_dict[self.lang]["max_seq_len"] = self.max_seq_len

    def tokenize(self):
        """Replace every sentence by the list of its words' vocabulary positions.

        Raises:
            ValueError: if a word is missing from ``self.vocab``.
        """
        # One dict lookup per token instead of a linear ``list.index`` scan.
        word_to_id = {}
        for position, word in enumerate(self.vocab):
            # setdefault keeps the first position, matching ``list.index``.
            word_to_id.setdefault(word, position)

        def encode(seq):
            try:
                return [word_to_id[word] for word in seq.split()]
            except KeyError as err:
                # Keep the exception type that ``list.index`` used to raise.
                raise ValueError("{!r} is not in vocab".format(err.args[0])) from None

        self._rewrite_each(encode)

    def pad(self, max_seq_len=const.MAX_SEQ_LEN):
        """Pad or truncate every train and val sequence to ``max_seq_len`` items.

        Both splits use the ``max_seq_len`` argument, so they come out the
        same length.
        """
        self._rewrite_each(
            lambda seq: pad_sequence(sequence=seq, max_seq_len=max_seq_len)
        )

    def set_lang(self, lang):
        """Select the language to process and reset the running maximum length."""
        self.lang = lang
        self.max_seq_len = 0

    def get_dict(self):
        """Return the dataset dict that was passed to the constructor."""
        return self.dataset_dict
class Dataset:
    """Build train/val ``tf.data`` pipelines for a pair of languages.

    ``langs[0]`` is used as the source language and ``langs[1]`` as the
    target language when the pipelines are assembled in ``pull``.
    """

    def __init__(self, langs):
        self.dataset_dict = {}
        self.langs = langs

    def pack(self):
        """Load and split every language, storing one dict per language."""
        # One loader for all languages so they share its shuffle state.
        loader = SequenceLoader()
        for lang in self.langs:
            loader.set_lang(lang)
            loader.pack()
            self.dataset_dict[lang] = loader.get_dict()

    def process(self):
        """Run the preprocessing steps, in order, on every language."""
        processor = SequenceProcessor(self.dataset_dict)
        steps = (
            processor.remove_punctuation,
            processor.add_start_end_tags,
            processor.build_vocab,
            processor.tokenize,
            processor.pad,
        )
        for lang in self.langs:
            processor.set_lang(lang)
            for step in steps:
                step()
        self.dataset_dict = processor.get_dict()

    def pull(self):
        """Return ``(train_ds, val_ds)`` as shuffled, batched, serialized datasets."""
        src_lang, tar_lang = self.langs[0], self.langs[1]

        def as_batches(split):
            # Source/target arrays -> paired slices -> shuffle -> batch -> serialize.
            src = np.array(self.dataset_dict[src_lang][split])
            tar = np.array(self.dataset_dict[tar_lang][split])
            pairs = tf.data.Dataset.from_tensor_slices((src, tar))
            batched = pairs.shuffle(const.BUFFER_SIZE).batch(const.BATCH_SIZE)
            return batched.map(serialize, tf.data.AUTOTUNE)

        return as_batches("train"), as_batches("val")

    def get_dict(self):
        """Return the per-language dataset dict."""
        return self.dataset_dict
if __name__ == "__main__":
    # Single location used both for writing and for re-reading the snapshot.
    snapshot_path = "{}/dataset.txt".format(const.data_dir)

    dataset_object = Dataset(const.langs)
    dataset_object.pack()

    # The dict is saved before process() runs, i.e. right after loading/splitting.
    dataset_dict = dataset_object.get_dict()
    utils.save_dict(snapshot_path, dataset_dict)

    dataset_object.process()
    trainset, valset = dataset_object.pull()

    print(utils.load_dict(snapshot_path))