Spaces:
Sleeping
Sleeping
File size: 2,466 Bytes
ed8878f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
''' Import Modules '''
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import Encoder, Decoder, Attention, Seq2Seq
from utils import Lang
import pickle
import re
import gradio as gr
# --- Hyperparameters -------------------------------------------------------
# Inference target device.
DEVICE = 'cpu'
# Sizes used to build the network before the saved weights are loaded;
# presumably they must match the values the checkpoint was trained with.
EMBED_DIM, HIDDEN_SIZE = 100, 128
# The demo translates one sentence per request.
BATCH_SIZE = 1
# --- Setup: load vocabularies and the trained translation model ------------
en_lang_path = 'en_vocab_obj.pkl'
fr_lang_path = 'fr_vocab_obj.pkl'

# NOTE(review): pickle.load and torch.load can execute arbitrary code while
# unpickling; these files are assumed to be trusted artifacts shipped with
# this script and must never come from user input.
with open(en_lang_path, 'rb') as file:
    en_lang = pickle.load(file)
with open(fr_lang_path, 'rb') as file:
    fr_lang = pickle.load(file)

enc = Encoder(en_lang.n_words, EMBED_DIM, HIDDEN_SIZE, HIDDEN_SIZE)
attn = Attention(HIDDEN_SIZE, HIDDEN_SIZE)
dec = Decoder(EMBED_DIM, fr_lang.n_words, HIDDEN_SIZE, HIDDEN_SIZE, attn)
net = Seq2Seq(enc, dec)

weights_path = 'net_weights_1.pth'
net.load_state_dict(torch.load(weights_path, map_location=DEVICE))
# Put the model in evaluation mode: torch.inference_mode() only disables
# autograd, it does not turn off training-only behaviour such as dropout.
net.eval()
# Characters treated as separators inside a chunk: anything that is not a
# word character, plus whitespace and digits.
regex_pattern = r"[\W\s\d]+"

# Contraction table stored as consecutive (full form, contraction) pairs.
# Contractions are spelled as they look *after* regex_pattern has replaced
# the apostrophe with a space, e.g. "i'm" -> "i m".
eng_prefixes = [
    "i will", "i ll",
    "i am ", "i m",
    "i have", "i ve",
    "he is", "he s",
    "she is", "she s",
    "you are", "you re",
    "we are", "we re",
    "they are", "they re",
    "i did", "i d",
]

# contraction -> full form, built from the pairs above. Only contractions are
# keys, so a full form that happens to appear in the input is left as-is
# instead of being rewritten into a neighbouring entry.
_CONTRACTIONS = dict(zip(eng_prefixes[1::2], eng_prefixes[0::2]))


def preprocess(sentence: str):
    """Lower-case and tokenise an English sentence for the encoder.

    Each whitespace-separated chunk is cleaned with ``regex_pattern``. If the
    cleaned chunk is a known contraction (e.g. "i m"), it is expanded to its
    full form (e.g. "i am"). Every result is then split on whitespace, so no
    returned token contains a space.

    Args:
        sentence: Raw user input.

    Returns:
        A list of lower-case word tokens (empty for blank input).
    """
    sequence = []
    for chunk in sentence.lower().split():
        word = re.sub(regex_pattern, ' ', chunk).strip()
        word = _CONTRACTIONS.get(word, word)
        # Punctuation inside a chunk ("hello,world", "don't") became a space,
        # so split again to keep one word per token; presumably the
        # vocabulary only holds single words — TODO confirm against utils.Lang.
        sequence.extend(word.split())
    return sequence
def main(sentence: str):
    """Translate an English sentence to French with the loaded Seq2Seq model.

    Args:
        sentence: Raw English text from the UI.

    Returns:
        The predicted French words joined by single spaces, cut at the first
        'EOS' token and with every 'SOS' token removed.
    """
    tokens = preprocess(sentence)
    # NOTE(review): 0 / 1 / 2 look like the SOS / EOS / unknown-word ids of
    # the vocabulary — confirm against utils.Lang.
    indices = [0] + [en_lang.word_to_index.get(word, 2) for word in tokens] + [1]
    source = torch.tensor(indices).unsqueeze(0)  # shape (1, len(indices))
    # Placeholder target of length 10; presumably it fixes the number of
    # decoding steps and its values are unused when the third argument
    # (teacher-forcing ratio, presumably) is 0 — TODO confirm in model.py.
    target = torch.zeros(BATCH_SIZE, 10, fr_lang.n_words)

    with torch.inference_mode():
        logits = net(source, target, 0)

    # softmax is monotonic, so argmax over the raw scores selects the same
    # tokens without the extra normalisation pass.
    pred = torch.argmax(logits, dim=2).squeeze(0).tolist()
    words = [fr_lang.index_to_word[idx] for idx in pred]

    if 'EOS' in words:
        words = words[:words.index('EOS')]
    # 'SOS' is a control token and never part of the translation text.
    words = [word for word in words if word != 'SOS']
    return ' '.join(words)
# Web UI: one text box for the English input, one for the French output.
_source_box = gr.Textbox()
_translation_box = gr.Textbox()
app = gr.Interface(fn=main, inputs=_source_box, outputs=_translation_box)
app.launch()
|