Spaces:
Build error
Build error
| import streamlit as st | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| from torch.utils.data import Dataset | |
| from transformer import Transformer | |
# Locations of the parallel training corpus. They are not read anywhere else in
# this file; inference only needs the checkpoint.
english_file, spanish_file = 'dataset/english.txt', 'dataset/spanish.txt'

# Special tokens. Each one is also an entry in both vocabularies, so it has its
# own id.
START_TOKEN, PADDING_TOKEN, END_TOKEN = '<START>', '<PADDING>', '<END>'
# Character-level vocabulary for the English (source) side: every entry is one
# token and its list position is its id (english_to_index / index_to_english
# are built from this list). The ids presumably have to match the ones the
# "englishTOspanish.pt" checkpoint was trained with, so treat the order as
# frozen — confirm before editing.
# NOTE(review): several characters are listed twice ('ü', 'å', 'Å', 'æ', 'Æ',
# 'œ', 'Œ'). A {char: index} map built from this list keeps only the LAST
# position of each, so the earlier slots are never produced by lookup.
# Removing the duplicates would change len() and shift every later id, which
# would presumably require retraining.
# NOTE(review): plain ASCII capitals 'A'-'Z' are not in the list (only accented
# capitals are), so mixed-case input may not be encodable.
english_vocabulary = [START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '’',
                      '‘', ';', '₂',
                      '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                      ':', '<', '=', '>', '?', '@',
                      '[', '\\', ']', '^', '_', '`',
                      'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
                      'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x',
                      'y', 'z',
                      'á', 'é', 'í', 'ó', 'ú', 'ñ', 'ü',  # first 'ü'
                      '¿', '¡',
                      'Á', 'É', 'Í', 'Ó', 'Ú', 'Ñ', 'Ü',
                      '{', '|', '}', '~', PADDING_TOKEN, END_TOKEN,  # special tokens sit mid-list
                      'à', 'è', 'ì', 'ò', 'ù', 'À', 'È', 'Ì', 'Ò', 'Ù',
                      'â', 'ê', 'î', 'ô', 'û', 'Â', 'Ê', 'Î', 'Ô', 'Û',
                      'ä', 'ë', 'ï', 'ö', 'ü', 'Ä', 'Ë', 'Ï', 'Ö',  # second 'ü' (duplicate)
                      'ã', 'õ', 'Ã', 'Õ',
                      'ā', 'ē', 'ī', 'ō', 'ū', 'Ā', 'Ē', 'Ī', 'Ō', 'Ū',
                      'ą', 'ę', 'į', 'ǫ', 'ų', 'Ą', 'Ę', 'Į', 'Ǫ', 'Ų',
                      'ç', 'Ç', 'ş', 'Ş', 'ğ', 'Ğ', 'ń', 'Ń', 'ś', 'Ś', 'ź', 'Ź', 'ż', 'Ż',
                      'č', 'Č', 'ć', 'Ć', 'đ', 'Đ', 'ł', 'Ł', 'ř', 'Ř', 'š', 'Š', 'ť', 'Ť',
                      'ý', 'ÿ', 'Ý', 'Ÿ', 'ž', 'Ž', 'ß', 'œ', 'Œ', 'æ', 'Æ', 'å', 'Å', 'ø', 'Ø', 'å', 'Å',  # 'å','Å' repeated
                      'æ', 'Æ', 'œ', 'Œ']  # repeats of 'æ','Æ','œ','Œ' from the previous line
# Character-level vocabulary for the Spanish (target) side. List position is the
# token id: translate() maps the model's argmax ids back to characters through
# index_to_spanish, and len(spanish_vocabulary) is passed to Transformer as
# es_vocab_size. The list appears identical, entry for entry, to
# english_vocabulary.
# NOTE(review): the order and length are presumably fixed by the saved
# checkpoint — do not reorder, add, or deduplicate entries without retraining.
# 'ü', 'å', 'Å', 'æ', 'Æ', 'œ', 'Œ' each appear twice; spanish_to_index keeps
# only the last position of each.
spanish_vocabulary = [START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '’',
                      '‘', ';', '₂',
                      '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                      ':', '<', '=', '>', '?', '@',
                      '[', '\\', ']', '^', '_', '`',
                      'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
                      'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x',
                      'y', 'z',
                      'á', 'é', 'í', 'ó', 'ú', 'ñ', 'ü',
                      '¿', '¡',
                      'Á', 'É', 'Í', 'Ó', 'Ú', 'Ñ', 'Ü',
                      '{', '|', '}', '~', PADDING_TOKEN, END_TOKEN,
                      'à', 'è', 'ì', 'ò', 'ù', 'À', 'È', 'Ì', 'Ò', 'Ù',
                      'â', 'ê', 'î', 'ô', 'û', 'Â', 'Ê', 'Î', 'Ô', 'Û',
                      'ä', 'ë', 'ï', 'ö', 'ü', 'Ä', 'Ë', 'Ï', 'Ö',
                      'ã', 'õ', 'Ã', 'Õ',
                      'ā', 'ē', 'ī', 'ō', 'ū', 'Ā', 'Ē', 'Ī', 'Ō', 'Ū',
                      'ą', 'ę', 'į', 'ǫ', 'ų', 'Ą', 'Ę', 'Į', 'Ǫ', 'Ų',
                      'ç', 'Ç', 'ş', 'Ş', 'ğ', 'Ğ', 'ń', 'Ń', 'ś', 'Ś', 'ź', 'Ź', 'ż', 'Ż',
                      'č', 'Č', 'ć', 'Ć', 'đ', 'Đ', 'ł', 'Ł', 'ř', 'Ř', 'š', 'Š', 'ť', 'Ť',
                      'ý', 'ÿ', 'Ý', 'Ÿ', 'ž', 'Ž', 'ß', 'œ', 'Œ', 'æ', 'Æ', 'å', 'Å', 'ø', 'Ø', 'å', 'Å',
                      'æ', 'Æ', 'œ', 'Œ']
# Lookup tables between vocabulary positions and tokens.
# Some characters appear twice in the vocabularies. The id -> token maps keep
# every position. The token -> id maps keep only the LAST position of a repeated
# character, because later dict entries overwrite earlier ones.
index_to_english = dict(enumerate(english_vocabulary))
english_to_index = {token: position for position, token in enumerate(english_vocabulary)}
index_to_spanish = dict(enumerate(spanish_vocabulary))
spanish_to_index = {token: position for position, token in enumerate(spanish_vocabulary)}
# Transformer architecture settings. The module-level model below is built from
# them; presumably they must agree with the architecture stored in
# "englishTOspanish.pt" — confirm before changing.
d_model, ffn_hidden, num_heads = 512, 2048, 8
drop_prob, num_layers = 0.1, 1
# Maximum number of positions per sentence; create_masks() sizes its masks with it.
max_sequence_length = 200
# Training-time setting; nothing in this inference script reads it.
batch_size = 30
# Size of the target-side token set, passed to Transformer below.
es_vocab_size = len(spanish_vocabulary)

# NOTE(review): this module-level instance is never used for inference.
# translate() obtains its own model, and no checkpoint is loaded into this one
# anywhere in this file, so it only costs construction time whenever the script
# runs.
transformer = Transformer(
    d_model, ffn_hidden, num_heads, drop_prob, num_layers,
    max_sequence_length, es_vocab_size,
    english_to_index, spanish_to_index,
    START_TOKEN, END_TOKEN, PADDING_TOKEN,
)
class TextDataset(Dataset):
    """Map-style dataset over paired English/Spanish sentences.

    Item ``i`` is the pair ``(english_sentences[i], spanish_sentences[i])``. The
    two sequences are expected to be aligned by position.
    """

    def __init__(self, english_sentences, spanish_sentences):
        self.english_sentences, self.spanish_sentences = english_sentences, spanish_sentences

    def __len__(self):
        # The dataset length is taken from the English side only.
        return len(self.english_sentences)

    def __getitem__(self, idx):
        english = self.english_sentences[idx]
        spanish = self.spanish_sentences[idx]
        return english, spanish
# Inference is pinned to the CPU. translate() loads the checkpoint with
# map_location on the CPU and never moves the model, so the attention masks it
# sends to `device` must stay on the CPU as well. Using CUDA would also require
# moving the model.
device = "cpu"

# Value placed in attention masks at blocked positions. It is presumably added
# to the attention scores inside Transformer so those positions get ~0 weight
# after softmax.
NEG_INFTY = -1e9
def create_masks(eng_batch, kn_batch):
    """Build additive attention masks for a batch of source/target sentences.

    ``eng_batch`` holds the English (encoder-side) sentences. ``kn_batch`` holds
    the decoder-side sentences; in this app translate() passes the Spanish
    prefix there, so the ``kn`` name looks like a leftover.

    Position ``p`` of sentence ``i`` counts as padding when
    ``p > len(sentence_i)``. That leaves one slot past the last character
    unmasked, presumably for a start token the model adds — confirm.

    Returns:
        ``(encoder_self_attention_mask, decoder_self_attention_mask,
        decoder_cross_attention_mask)``. Each has shape
        ``(len(eng_batch), max_sequence_length, max_sequence_length)`` with
        ``NEG_INFTY`` at blocked (query, key) pairs and 0 elsewhere. The decoder
        self-attention mask also blocks every key that lies after its query.
    """
    count = len(eng_batch)
    positions = torch.arange(max_sequence_length)

    src_lengths = torch.tensor([len(eng_batch[i]) for i in range(count)], dtype=torch.long)
    tgt_lengths = torch.tensor([len(kn_batch[i]) for i in range(count)], dtype=torch.long)

    # (count, seq): True where the position lies beyond the sentence (+1 slot).
    src_pad = positions.unsqueeze(0) > src_lengths.unsqueeze(1)
    tgt_pad = positions.unsqueeze(0) > tgt_lengths.unsqueeze(1)

    # (count, query, key): a pair is blocked when either its key (last axis)
    # or its query (middle axis) is padding.
    encoder_blocked = src_pad.unsqueeze(1) | src_pad.unsqueeze(2)
    decoder_self_blocked = tgt_pad.unsqueeze(1) | tgt_pad.unsqueeze(2)
    cross_blocked = src_pad.unsqueeze(1) | tgt_pad.unsqueeze(2)

    # Causal mask: query r may not attend to any key c > r.
    causal = torch.triu(
        torch.ones(max_sequence_length, max_sequence_length, dtype=torch.bool),
        diagonal=1,
    )

    return (
        torch.where(encoder_blocked, NEG_INFTY, 0),
        torch.where(causal | decoder_self_blocked, NEG_INFTY, 0),
        torch.where(cross_blocked, NEG_INFTY, 0),
    )
def _build_translation_model():
    """Construct the English->Spanish Transformer and load its trained weights.

    The weights are read from "englishTOspanish.pt", mapped onto the CPU, and
    the model is returned in eval mode. The architecture values below must
    describe the checkpoint, because ``load_state_dict`` is strict by default
    and rejects mismatched parameters.
    """
    d_model = 512
    ffn_hidden = 2048
    num_heads = 8
    drop_prob = 0.1
    num_layers = 1
    max_sequence_length = 200
    es_vocab_size = len(spanish_vocabulary)
    model = Transformer(d_model,
                        ffn_hidden,
                        num_heads,
                        drop_prob,
                        num_layers,
                        max_sequence_length,
                        es_vocab_size,
                        english_to_index,
                        spanish_to_index,
                        START_TOKEN,
                        END_TOKEN,
                        PADDING_TOKEN)
    model.load_state_dict(torch.load("englishTOspanish.pt", map_location=torch.device('cpu')))
    model.eval()
    return model


# Streamlit re-executes this script on every interaction. Without a resource
# cache, every "Translate" click rebuilt the model and re-read the checkpoint.
# st.cache_resource (Streamlit >= 1.18) keeps one shared instance across
# reruns. On older Streamlit versions we fall back to loading on every call,
# which is the previous behaviour.
if hasattr(st, "cache_resource"):
    _build_translation_model = st.cache_resource(_build_translation_model)


def translate(eng_sentence):
    """Greedily decode a Spanish translation of ``eng_sentence``.

    Tokens are vocabulary entries (single characters plus the special tokens).
    At step ``i`` the model receives the English sentence and the Spanish prefix
    generated so far, and the argmax of its output at position ``i`` is appended
    to the prefix. Decoding stops as soon as ``END_TOKEN`` is produced, or after
    200 steps.

    Args:
        eng_sentence: English text to translate.

    Returns:
        The generated Spanish string. It ends with ``END_TOKEN`` when the model
        emitted it, and lacks it when decoding stopped at the step limit.
    """
    model = _build_translation_model()
    # Step limit; matches the module-level max_sequence_length (200) that
    # create_masks() uses to size its masks.
    max_steps = 200
    eng_batch = (eng_sentence,)
    es_sentence = ""
    # Pure inference: disable autograd so the repeated forward passes do not
    # build and keep gradient graphs.
    with torch.no_grad():
        for step in range(max_steps):
            es_batch = (es_sentence,)
            encoder_mask, decoder_self_mask, decoder_cross_mask = create_masks(eng_batch, es_batch)
            predictions = model(eng_batch,
                                es_batch,
                                encoder_mask.to(device),
                                decoder_self_mask.to(device),
                                decoder_cross_mask.to(device),
                                enc_start_token=False,
                                enc_end_token=False,
                                dec_start_token=True,
                                dec_end_token=False)
            # With dec_start_token=True the decoder input is presumably
            # START_TOKEN + prefix, so output position `step` scores the next
            # token — confirm against Transformer.
            next_token_index = torch.argmax(predictions[0][step]).item()
            next_token = index_to_spanish[next_token_index]
            es_sentence += next_token
            if next_token == END_TOKEN:
                break
    return es_sentence
# --- Streamlit page -----------------------------------------------------------
st.title("seq2seq Machine Translation")
st.write("Translate English to Spanish")
st.write("\n")
st.write("Some example sentences:")
st.write("i'm happy to see you here")
st.write("i have nothing to do with it")
st.write("what did you say yesterday?")
st.write("\n")

input_text = st.text_area("Enter English text:")
if st.button("Translate"):
    # Drop leading/trailing whitespace such as a trailing newline. '\n' is not
    # an entry of english_vocabulary, so it should not reach the model.
    cleaned_text = input_text.strip()
    if cleaned_text == "":
        st.warning("Please enter some text.")
    else:
        # NOTE(review): english_vocabulary has lowercase ASCII letters only;
        # confirm how Transformer handles characters outside the vocabulary
        # (e.g. 'A'-'Z') before relying on mixed-case input.
        translated_text = translate(cleaned_text)
        # translate() ends its result with END_TOKEN only when the model emitted
        # it. The old unconditional [:-5] chopped five real characters whenever
        # decoding hit the length limit instead.
        if translated_text.endswith(END_TOKEN):
            translated_text = translated_text[:-len(END_TOKEN)]
        st.write("Your Text (English):")
        st.title(cleaned_text)
        st.write("Translated Text (Spanish):")
        st.title(translated_text)