"""Gradio demo that translates an English sentence to French with a Seq2Seq model.

Importing/running this script loads the pickled vocabularies and the trained
weights from the working directory, builds a one-textbox Gradio interface and
launches it.
"""

# --- Import modules ---------------------------------------------------------
import pickle
import random
import re

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F

from model import Encoder, Decoder, Attention, Seq2Seq
# NOTE(review): Lang is not referenced directly, but it is presumably the class
# of the pickled vocabulary objects and must stay importable for unpickling.
from utils import Lang

# --- Hyperparameters --------------------------------------------------------
DEVICE = 'cpu'
EMBED_DIM = 100
HIDDEN_SIZE = 128
BATCH_SIZE = 1

# Number of decoding steps allocated for the output sequence.
MAX_OUTPUT_LEN = 10

# Token ids wrapped around / substituted into the source sequence.
# NOTE(review): presumably the SOS / EOS / unknown-word ids used by utils.Lang
# (the decoded output is stripped of 'SOS' and 'EOS' below) - confirm.
SOS_INDEX = 0
EOS_INDEX = 1
UNK_INDEX = 2

# --- Setup ------------------------------------------------------------------
en_lang_path = 'en_vocab_obj.pkl'
fr_lang_path = 'fr_vocab_obj.pkl'

# SECURITY: pickle.load executes arbitrary code embedded in the file. These
# paths must only ever point at trusted, locally produced vocabulary files.
with open(en_lang_path, 'rb') as file:
    en_lang = pickle.load(file)

with open(fr_lang_path, 'rb') as file:
    fr_lang = pickle.load(file)

enc = Encoder(en_lang.n_words, EMBED_DIM, HIDDEN_SIZE, HIDDEN_SIZE)
attn = Attention(HIDDEN_SIZE, HIDDEN_SIZE)
dec = Decoder(EMBED_DIM, fr_lang.n_words, HIDDEN_SIZE, HIDDEN_SIZE, attn)
net = Seq2Seq(enc, dec)

weights_path = 'net_weights_1.pth'
net.load_state_dict(torch.load(weights_path, map_location=DEVICE))
# inference_mode() only disables autograd; eval() is what switches layers such
# as dropout / batch-norm to their deterministic inference behaviour.
net.eval()

regex_pattern = r"[\W\s\d]+"
_NON_WORD_RE = re.compile(regex_pattern)

# Laid out as (full form, contracted form) pairs: every even index holds the
# expansion of the contraction that follows it. The original list was missing
# the comma after "i ll", which fused it with the next literal and shifted
# every later pair by one.
eng_prefixes = [
    "i will", "i ll",
    "i am", "i m",
    "i have", "i ve",
    "he is", "he s",
    "she is", "she s",
    "you are", "you re",
    "we are", "we re",
    "they are", "they re",
    "i did", "i d",
]

# Maps both spellings of each pair to the tokens of the full form. An explicit
# table replaces the old `eng_prefixes[index - 1]` lookup, which returned the
# previous pair's contraction when the full form itself was matched (and
# wrapped around to the last entry for index 0).
_PREFIX_TOKENS = {}
for _full, _short in zip(eng_prefixes[0::2], eng_prefixes[1::2]):
    _PREFIX_TOKENS[_full] = _full.split()
    _PREFIX_TOKENS[_short] = _full.split()


def preprocess(sentence: str) -> list:
    """Lower-case and tokenize *sentence* for the encoder vocabulary.

    The sentence is split on single spaces. Inside each word every run of
    non-word characters, whitespace and digits is replaced by one space and
    the result is stripped. Words that match an entry of ``eng_prefixes``
    (e.g. ``"i'm"`` -> ``"i m"``) are replaced by the tokens of the full form
    (``"i"``, ``"am"``); words that end up empty are dropped.

    Args:
        sentence: Raw user input.

    Returns:
        The list of token strings, in input order.
    """
    sequence = []
    for raw_word in sentence.lower().split(' '):
        word = _NON_WORD_RE.sub(' ', raw_word).strip()
        if not word:
            # Pure punctuation / digits, or the gap between repeated spaces.
            continue
        expansion = _PREFIX_TOKENS.get(word)
        if expansion is not None:
            sequence.extend(expansion)
        else:
            # NOTE(review): a word with inner punctuation (e.g. "don't" ->
            # "don t") is kept as ONE token containing a space, exactly as
            # before; confirm this matches the training-time tokenization.
            sequence.append(word)
    return sequence


def main(sentence: str) -> str:
    """Translate *sentence* and return the predicted words joined by spaces.

    Args:
        sentence: Raw English text from the Gradio textbox.

    Returns:
        The decoded target-language words up to (not including) the first
        ``'EOS'``, with the first ``'SOS'`` removed. An empty string is
        returned when the input contains no usable tokens.
    """
    tokens = preprocess(sentence)
    if not tokens:
        # Nothing to translate; avoid decoding a bare SOS/EOS pair.
        return ''

    indices = [SOS_INDEX]
    indices.extend(en_lang.word_to_index.get(word, UNK_INDEX) for word in tokens)
    indices.append(EOS_INDEX)

    source = torch.tensor(indices).unsqueeze(0)
    # All-zero placeholder handed to the network as its target argument; its
    # second dimension fixes how many positions are decoded.
    target = torch.zeros(BATCH_SIZE, MAX_OUTPUT_LEN, fr_lang.n_words)

    with torch.inference_mode():
        # Third argument is 0 - presumably the teacher-forcing ratio, so the
        # decoder consumes its own predictions (confirm against Seq2Seq).
        scores = net(source, target, 0)
        # argmax over raw scores picks the same index as argmax over their
        # softmax, so the softmax pass is skipped.
        predicted = torch.argmax(scores, dim=2).squeeze(0).tolist()

    words = [fr_lang.index_to_word[idx] for idx in predicted]

    if 'EOS' in words:
        words = words[:words.index('EOS')]

    if 'SOS' in words:
        words.remove('SOS')

    return ' '.join(words)


app = gr.Interface(
    fn=main,
    inputs=gr.Textbox(),
    outputs=gr.Textbox(),
)
app.launch()