| from transformer import Transformer |
| import torch |
| import numpy as np |
| import chardet |
| import matplotlib.pyplot as plt |
| from torch import nn |
import os

# Location of the parallel corpus: line i of train.en and line i of train.mr
# form one translation pair.  Defaults to the original author's machine; set
# ENG_MARATHI_DATA_DIR to the directory holding train.en / train.mr to run the
# script elsewhere without editing it.
_data_dir = os.environ.get('ENG_MARATHI_DATA_DIR')
if _data_dir:
    english_file = os.path.join(_data_dir, 'train.en')
    marathi_file = os.path.join(_data_dir, 'train.mr')
else:
    english_file = r'C:\Users\haris\Downloads\eng_marathi\train.en'
    marathi_file = r'C:\Users\haris\Downloads\eng_marathi\train.mr'
|
|
| |
|
|
# Special tokens.  They are multi-character strings, so they can never be equal
# to any of the single-character entries in the vocabularies below.
START_TOKEN = '<START>'
PADDING_TOKEN = '<PADDING>'
END_TOKEN = '<END>'


# Character-level Marathi (target) vocabulary; a token's index is its position
# in this list (see the enumerate-based lookup tables that follow).
# NOTE(review): the third line holds characters that appear to be Telugu
# letters / vowel signs rather than Devanagari, while the Devanagari anusvara
# 'ं' (U+0902), visarga 'ः' (U+0903), the Devanagari digits and ';' are absent.
# Sentences containing any missing character are rejected by is_valid_tokens,
# so this list directly controls how much data survives filtering — confirm
# the omissions are intended.
marathi_vocabulary = [START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/',
                      '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', '<', '=', '>', '?', 'ˌ',
                      'ँ', 'ఆ', 'ఇ', 'ా', 'ి', 'ీ', 'ు', 'ూ',
                      'अ', 'आ', 'इ', 'ई', 'उ', 'ऊ', 'ऋ', 'ॠ', 'ऌ', 'ऎ', 'ए', 'ऐ', 'ऒ', 'ओ', 'औ',
                      'क', 'ख', 'ग', 'घ', 'ङ',
                      'च', 'छ', 'ज', 'झ', 'ञ',
                      'ट', 'ठ', 'ड', 'ढ', 'ण',
                      'त', 'थ', 'द', 'ध', 'न',
                      'प', 'फ', 'ब', 'भ', 'म',
                      'य', 'र', 'ऱ', 'ल', 'ळ', 'व', 'श', 'ष', 'स', 'ह',
                      '़', 'ऽ', 'ा', 'ि', 'ी', 'ु', 'ू', 'ृ', 'ॄ', 'ॅ', 'े', 'ै', 'ॉ', 'ो', 'ौ', '्', 'ॐ', '।', '॥', '॰', 'ॱ', PADDING_TOKEN, END_TOKEN]


# Character-level English (source) vocabulary: printable ASCII without the
# upper-case letters (English sentences are lower-cased when loaded).
# NOTE(review): ';' is missing from this list, so English text containing it
# cannot be tokenized — confirm whether that is intended.
english_vocabulary = [START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/',
                      '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                      ':', '<', '=', '>', '?', '@',
                      '[', '\\', ']', '^', '_', '`',
                      'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
                      'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x',
                      'y', 'z',
                      '{', '|', '}', '~', PADDING_TOKEN, END_TOKEN]
# Two-way lookups between a token's position in its vocabulary list and the
# token string itself.
index_to_marathi = dict(enumerate(marathi_vocabulary))
marathi_to_index = {token: position for position, token in enumerate(marathi_vocabulary)}
index_to_english = dict(enumerate(english_vocabulary))
english_to_index = {token: position for position, token in enumerate(english_vocabulary)}
|
|
| |
# Guess the Marathi file's text encoding from its first 10 kB, then decode the
# whole file with it.  chardet may report None when it cannot decide; without a
# fallback, open(encoding=None) would silently use the platform's locale
# encoding (often not UTF-8 on Windows), so fall back to UTF-8 explicitly.
with open(marathi_file, 'rb') as file:
    raw_data = file.read(10000)
result = chardet.detect(raw_data)
encoding = result['encoding'] or 'utf-8'
print(f"Detected encoding: {encoding}")

with open(marathi_file, 'r', encoding=encoding) as file:
    marathi_sentences = file.readlines()
|
|
| |
# The English side of the corpus is decoded as UTF-8, one list entry per line
# (trailing newlines are kept here and stripped later).
with open(english_file, encoding='utf-8') as english_fh:
    english_sentences = list(english_fh)
|
|
| |
# Cap the corpus size, then drop each line's trailing newline.  English is
# lower-cased because english_vocabulary contains no upper-case letters.
TOTAL_SENTENCES = 20000
english_sentences = [line.rstrip('\n').lower() for line in english_sentences[:TOTAL_SENTENCES]]
marathi_sentences = [line.rstrip('\n') for line in marathi_sentences[:TOTAL_SENTENCES]]

# Longest token sequence the pipeline handles (used for filtering and masks).
max_sequence_length = 200
|
|
def is_valid_tokens(sentence, vocab):
    """Return True if every character of ``sentence`` occurs in ``vocab``.

    Args:
        sentence: the text to check; it is inspected character by character.
        vocab: iterable of allowed tokens (here, a vocabulary list).

    Returns:
        True when no character of ``sentence`` is missing from ``vocab``
        (an empty sentence is trivially valid), otherwise False.
    """
    # A set-subset check hashes the vocabulary once, instead of scanning the
    # list linearly for every distinct character.
    return set(sentence).issubset(vocab)
|
|
def is_valid_length(sentence, max_sequence_length):
    """Return True if ``sentence`` is strictly shorter than ``max_sequence_length - 1``.

    Length is counted in characters.  The ``- 1`` together with the strict
    ``<`` leaves at least two free positions, presumably so that a <START> and
    an <END> token can still be added within ``max_sequence_length``.

    Args:
        sentence: the text to measure.
        max_sequence_length: total number of token positions available.

    Returns:
        True when ``len(sentence) <= max_sequence_length - 2``.
    """
    # len() on the string directly; no need to copy it into a list first.
    return len(sentence) < max_sequence_length - 1
|
|
# Keep only sentence pairs the model can consume:
#   * both sides short enough for max_sequence_length (see is_valid_length);
#   * every character of each side present in that side's vocabulary.
# The English side has to be validated as well: english_vocabulary lacks e.g.
# ';' and all non-ASCII characters, and english_to_index is handed to the
# Transformer, presumably for tokenizing the English batches — an unknown
# character would otherwise fail there mid-training instead of being dropped.
valid_sentence_indicies = []
for index, marathi_sentence in enumerate(marathi_sentences):
    english_sentence = english_sentences[index]
    if (is_valid_length(marathi_sentence, max_sequence_length)
            and is_valid_length(english_sentence, max_sequence_length)
            and is_valid_tokens(marathi_sentence, marathi_vocabulary)
            and is_valid_tokens(english_sentence, english_vocabulary)):
        valid_sentence_indicies.append(index)


print(f"Number of sentences: {len(marathi_sentences)}")
print(f"Number of valid sentences: {len(valid_sentence_indicies)}")


# Apply the same index selection to both sides so the pairs stay aligned.
marathi_sentences = [marathi_sentences[i] for i in valid_sentence_indicies]
english_sentences = [english_sentences[i] for i in valid_sentence_indicies]
|
|
|
|
|
|
# Model and training hyperparameters.
d_model = 512
batch_size = 64
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
num_layers = 4
# max_sequence_length is deliberately NOT re-assigned here: the model and the
# attention masks must use the exact value the sentences were filtered with,
# and a second literal would let the two silently drift apart.
mr_vocab_size = len(marathi_vocabulary)
|
|
# Build the encoder-decoder model.  Every argument is positional, so the order
# must match Transformer.__init__ in the local `transformer` module (not
# visible here — TODO confirm).  Only the Marathi (target) vocabulary size is
# passed; the two token->index maps and the special-token strings are handed
# over as well, presumably so the model can tokenize raw sentence strings
# itself (the training loop feeds it un-tokenized string batches).
transformer = Transformer(d_model,
                          ffn_hidden,
                          num_heads,
                          drop_prob,
                          num_layers,
                          max_sequence_length,
                          mr_vocab_size,
                          english_to_index,
                          marathi_to_index,
                          START_TOKEN,
                          END_TOKEN,
                          PADDING_TOKEN)
|
|
| from torch.utils.data import Dataset, DataLoader |
|
|
class TextDataset(Dataset):
    """Map-style dataset of (English, Marathi) sentence pairs.

    Item ``idx`` is ``(english_sentences[idx], marathi_sentences[idx])``; the
    dataset length is taken from the English list, so both lists are expected
    to be the same length.
    """

    def __init__(self, english_sentences, marathi_sentences):
        self.english_sentences, self.marathi_sentences = english_sentences, marathi_sentences

    def __len__(self):
        return len(self.english_sentences)

    def __getitem__(self, idx):
        source = self.english_sentences[idx]
        target = self.marathi_sentences[idx]
        return source, target
|
|
|
|
# Wrap the filtered sentence pairs for mini-batching.  No eager iterator is
# created here: the training loop iterates the loader itself every epoch.
# NOTE(review): batches follow corpus order (shuffle defaults to False) —
# consider shuffle=True for training.
dataset = TextDataset(english_sentences, marathi_sentences)
train_loader = DataLoader(dataset, batch_size=batch_size)
| from torch import nn |
|
|
# Per-token cross-entropy (reduction='none') so the training loop can normalise
# by the number of real target tokens itself; targets equal to the padding
# index are ignored and contribute a loss of 0.
criterian = nn.CrossEntropyLoss(
    reduction='none',
    ignore_index=marathi_to_index[PADDING_TOKEN],
)
|
|
| |
# Xavier-uniform re-initialisation of every parameter with two or more
# dimensions; 1-D parameters keep the initialisation their modules gave them.
for weight in (p for p in transformer.parameters() if p.dim() > 1):
    nn.init.xavier_uniform_(weight)
|
|
# Choose the device and move the model there *before* constructing the
# optimizer, as the torch.optim documentation recommends, so the optimizer is
# built over the parameters in their final location.  (Moving again later is a
# harmless no-op.)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
transformer.to(device)
optim = torch.optim.Adam(transformer.parameters(), lr=1e-4)

# Value written into blocked attention positions by create_masks.  Large and
# negative so those positions get ~0 weight (presumably it is added to the
# attention scores before a softmax), but finite so that rows which are masked
# entirely (padding query rows) do not produce inf/NaN arithmetic.
NEG_INFTY = -1e9
|
|
def create_masks(eng_batch, mr_batch, max_len=None, neg_infty=None):
    """Build the three additive attention masks for a batch of sentence pairs.

    Sentence lengths are counted in characters (presumably one token per
    character).  The masks follow how the call sites tokenize:

    * encoder input: no <START>/<END> (the call sites pass
      ``enc_start_token=False, enc_end_token=False``), so positions
      ``[0, len(eng))`` are real and everything from ``len(eng)`` on is padding.
      The previous version masked only from ``len(eng) + 1``, leaving one
      padding position visible to the encoder and to cross-attention.
    * decoder input: <START> followed by the characters, so positions
      ``[0, len(mr)]`` are kept.  A trailing <END> input (when present) is
      masked together with the padding; no label needs it.

    Args:
        eng_batch: sequence of English source strings.
        mr_batch: sequence of Marathi target strings (same batch size).
        max_len: padded sequence length; defaults to the module-level
            ``max_sequence_length``.
        neg_infty: value placed at blocked positions; defaults to the
            module-level ``NEG_INFTY``.

    Returns:
        ``(encoder_self_attention_mask, decoder_self_attention_mask,
        decoder_cross_attention_mask)``, each a float tensor of shape
        ``(batch, max_len, max_len)`` holding 0 where attention is allowed and
        ``neg_infty`` where it is blocked.  Rows are indexed by the attending
        sentence's positions and columns by the attended sentence's positions
        (Marathi rows x English columns for cross-attention).  A padding
        position is blocked both as a row and as a column.
    """
    if max_len is None:
        max_len = max_sequence_length
    if neg_infty is None:
        neg_infty = NEG_INFTY

    positions = torch.arange(max_len)
    # Number of leading positions that hold real (non-padding) tokens.
    eng_lengths = torch.tensor([len(sentence) for sentence in eng_batch], dtype=torch.long)
    mr_lengths = torch.tensor([len(sentence) + 1 for sentence in mr_batch], dtype=torch.long)

    # (batch, max_len): True at padding positions.
    eng_is_pad = positions.unsqueeze(0) >= eng_lengths.unsqueeze(1)
    mr_is_pad = positions.unsqueeze(0) >= mr_lengths.unsqueeze(1)

    # (max_len, max_len): True strictly above the diagonal, i.e. future positions.
    look_ahead_mask = torch.triu(torch.ones(max_len, max_len, dtype=torch.bool), diagonal=1)

    # Broadcasting (batch, 1, L) | (batch, L, 1) blocks padded columns and rows.
    encoder_padding_mask = eng_is_pad.unsqueeze(1) | eng_is_pad.unsqueeze(2)
    decoder_padding_mask_self_attention = mr_is_pad.unsqueeze(1) | mr_is_pad.unsqueeze(2)
    decoder_padding_mask_cross_attention = eng_is_pad.unsqueeze(1) | mr_is_pad.unsqueeze(2)

    encoder_self_attention_mask = torch.where(encoder_padding_mask, neg_infty, 0)
    decoder_self_attention_mask = torch.where(look_ahead_mask | decoder_padding_mask_self_attention, neg_infty, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, neg_infty, 0)
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask
# Training mode on the selected device, plus the per-epoch loss history that
# the loop below fills.
transformer.to(device).train()
num_epochs, epoch_losses = 100, []
|
|
for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    total_loss = 0
    count_batches = 0
    # enumerate() creates the DataLoader iterator itself; no manual iter().
    for batch_num, (eng_batch, mr_batch) in enumerate(train_loader):
        # Cheap to repeat; keeps the model in training mode even if an
        # evaluation pass switched it to eval() in between.
        transformer.train()
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_batch, mr_batch)
        optim.zero_grad()
        mr_predictions = transformer(eng_batch,
                                     mr_batch,
                                     encoder_self_attention_mask.to(device),
                                     decoder_self_attention_mask.to(device),
                                     decoder_cross_attention_mask.to(device),
                                     enc_start_token=False,
                                     enc_end_token=False,
                                     dec_start_token=True,
                                     dec_end_token=True)
        # Targets: the Marathi characters followed by <END>, without <START>.
        # The decoder input starts with <START>, so output i is presumably
        # trained to predict decoder-input token i + 1.
        labels = transformer.decoder.sentence_embedding.batch_tokenize(mr_batch, start_token=False, end_token=True)
        labels = labels.view(-1).to(device)
        per_token_loss = criterian(mr_predictions.view(-1, mr_vocab_size).to(device), labels)
        # criterian ignores padding targets (they contribute 0), so average over
        # the real target tokens only.
        num_real_tokens = (labels != marathi_to_index[PADDING_TOKEN]).sum()
        loss = per_token_loss.sum() / num_real_tokens
        loss.backward()
        optim.step()
        batch_loss = loss.item()
        total_loss += batch_loss
        count_batches += 1

        if batch_num % 100 == 0:
            print(f"Iteration {batch_num} : {batch_loss}")
            print(f"English: {eng_batch[0]}")
            print(f"marathi Translation: {mr_batch[0]}")
            # Greedy read-out of the first sentence's predictions, cut at <END>.
            # .tolist() copies the indices to the host once instead of per element.
            predicted_indices = torch.argmax(mr_predictions[0], dim=1).tolist()
            end_index = marathi_to_index[END_TOKEN]
            predicted_chars = []
            for token_index in predicted_indices:
                if token_index == end_index:
                    break
                predicted_chars.append(index_to_marathi[token_index])
            predicted_sentence = "".join(predicted_chars)
            print(f"marathi Prediction: {predicted_sentence}")
    # An empty loader (e.g. every sentence filtered out) would otherwise raise
    # ZeroDivisionError here; record NaN so the missing epoch stays visible.
    average_loss = total_loss / count_batches if count_batches else float('nan')
    epoch_losses.append(average_loss)
    print(f"Average Loss for Epoch {epoch}: {average_loss}")
|
|
# Greedy, character-by-character translation of one fixed English sentence.
transformer.eval()
mr_sentence = ("",)
eng_sentence = ("should we go to the mall?",)
# Special tokens are multi-character strings.  Appending one to the decoder
# input string would presumably get it re-tokenized character by character
# (e.g. 'S' is not in the Marathi vocabulary) and derail decoding, so stop on
# any of them.  The terminating token is not added to the output.
stop_tokens = {END_TOKEN, PADDING_TOKEN, START_TOKEN}
# Inference only: no_grad() avoids building autograd graphs on every step.
with torch.no_grad():
    for word_counter in range(max_sequence_length):
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_sentence, mr_sentence)
        predictions = transformer(eng_sentence,
                                  mr_sentence,
                                  encoder_self_attention_mask.to(device),
                                  decoder_self_attention_mask.to(device),
                                  decoder_cross_attention_mask.to(device),
                                  enc_start_token=False,
                                  enc_end_token=False,
                                  dec_start_token=True,
                                  dec_end_token=False)
        # Output at position word_counter predicts the next character, since
        # the decoder input is <START> followed by word_counter characters.
        next_token_prob_distribution = predictions[0][word_counter]
        next_token_index = torch.argmax(next_token_prob_distribution).item()
        next_token = index_to_marathi[next_token_index]
        if next_token in stop_tokens:
            break
        mr_sentence = (mr_sentence[0] + next_token, )


print(f"Evaluation translation (should we go to the mall?) : {mr_sentence}")
print("-------------------------------------------")
| |