Spaces:
Sleeping
Sleeping
| ''' Import Modules ''' | |
| import random | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from model import Encoder, Decoder, Attention, Seq2Seq | |
| from utils import Lang | |
| import pickle | |
| import re | |
| import gradio as gr | |
''' Hyperparameters '''
DEVICE = 'cpu'
EMBED_DIM = 100
HIDDEN_SIZE = 128
BATCH_SIZE = 1

''' Setup '''
en_lang_path = 'en_vocab_obj.pkl'
fr_lang_path = 'fr_vocab_obj.pkl'

# NOTE(security): pickle.load can execute arbitrary code while unpickling.
# Only ship vocabulary files that were produced by this project.
with open(en_lang_path, 'rb') as file:
    en_lang = pickle.load(file)
with open(fr_lang_path, 'rb') as file:
    fr_lang = pickle.load(file)

# Assemble the network from the project's building blocks; the constructor
# arguments must match the ones used when the checkpoint below was saved.
enc = Encoder(en_lang.n_words, EMBED_DIM, HIDDEN_SIZE, HIDDEN_SIZE)
attn = Attention(HIDDEN_SIZE, HIDDEN_SIZE)
dec = Decoder(EMBED_DIM, fr_lang.n_words, HIDDEN_SIZE, HIDDEN_SIZE, attn)
net = Seq2Seq(enc, dec)

weights_path = 'net_weights_1.pth'
net.load_state_dict(torch.load(weights_path, map_location=DEVICE))

# Switch to evaluation mode: torch.inference_mode() only turns off autograd,
# it does not disable train-time layers such as dropout or batch norm.
net.eval()
# Any run of non-word characters, whitespace or digits becomes a single space.
regex_pattern = r"[\W\s\d]+"
_non_word_re = re.compile(regex_pattern)

# Flat list of (expanded form, contracted form) pairs.  The contracted forms
# are written the way they look after `regex_pattern` has replaced the
# apostrophe with a space, e.g. "i'm" -> "i m".
# Every entry needs its own trailing comma: adjacent string literals are
# silently concatenated by Python.
eng_prefixes = [
    "i will", "i ll",
    "i am", "i m",
    "i have", "i ve",
    "he is", "he s",
    "she is", "she s",
    "you are", "you re",
    "we are", "we re",
    "they are", "they re",
    "i did", "i d",
]

# Contracted form -> expanded form.  An explicit mapping guarantees that only
# contractions are rewritten; expanded forms are never used as lookup keys.
_contraction_expansions = dict(zip(eng_prefixes[1::2], eng_prefixes[0::2]))


def preprocess(sentence: str):
    """Normalise a sentence into a list of lowercase tokens.

    The sentence is lowercased and split on single spaces.  In every chunk,
    runs of punctuation, whitespace and digits are replaced by one space and
    the result is stripped.  Chunks that are known contractions (see
    ``eng_prefixes``) are replaced by the words of their expanded form;
    chunks that end up empty are dropped.

    Args:
        sentence: Raw input text.

    Returns:
        A list of token strings, empty when the input has no usable text.
    """
    sequence = []
    for chunk in sentence.lower().split(' '):
        word = _non_word_re.sub(' ', chunk).strip()
        if not word:
            # Chunk consisted solely of punctuation, digits or whitespace.
            continue
        expansion = _contraction_expansions.get(word)
        if expansion is not None:
            sequence.extend(expansion.split())
        else:
            # NOTE(review): a chunk with inner punctuation that is not a known
            # contraction (e.g. "don't" -> "don t") is kept as one token, as
            # before - confirm this matches how the vocabulary was built.
            sequence.append(word)
    return sequence
def main(sentence: str):
    """Run the loaded network on one sentence and return the decoded text.

    The sentence is tokenised with ``preprocess``, mapped to source-vocabulary
    indices, passed through ``net`` and the highest-scoring index at every
    output step is mapped back to a word of the target vocabulary.

    Args:
        sentence: Raw input text.

    Returns:
        The predicted words joined by single spaces, cut at the first 'EOS'
        and with the first 'SOS' removed.
    """
    # Words missing from the vocabulary fall back to index 2; the sequence is
    # framed by indices 0 and 1 (presumably SOS / EOS - confirm against Lang).
    lookup = en_lang.word_to_index
    encoded = [0]
    encoded.extend(lookup.get(token, 2) for token in preprocess(sentence))
    encoded.append(1)

    source = torch.tensor(encoded).unsqueeze(0)
    # All-zero stand-in for the target argument of the network
    # (presumably it fixes the decoded length at 10 steps - TODO confirm).
    target = torch.zeros(BATCH_SIZE, 10, fr_lang.n_words)

    with torch.inference_mode():
        # Third argument is 0 (presumably the teacher-forcing ratio - confirm).
        scores = net(source, target, 0)
        probabilities = torch.softmax(scores, dim=2)
        best_indices = probabilities.argmax(dim=2).squeeze(0).tolist()

    words = [fr_lang.index_to_word[i] for i in best_indices]

    # Keep only what precedes the first end marker.
    if 'EOS' in words:
        words = words[:words.index('EOS')]
    # Drop the first start marker, if one is left.
    if 'SOS' in words:
        words.remove('SOS')

    return ' '.join(words)
# One text box in, one text box out; the text is passed through `main`.
_input_box = gr.Textbox()
_output_box = gr.Textbox()
app = gr.Interface(fn=main, inputs=_input_box, outputs=_output_box)

# Start the Gradio server.
app.launch()