# File size: 2,466 Bytes
# ed8878f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
''' Import Modules '''
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

from model import Encoder, Decoder, Attention, Seq2Seq
from utils import Lang

import pickle
import re
import gradio as gr

''' Hyperparameters '''

# Name of the torch device; the visible code keeps everything on the CPU.
DEVICE = "cpu"

# Layer sizes handed to the Encoder / Attention / Decoder constructors below.
# The checkpoint is loaded into those modules, so these presumably have to
# agree with the values the weights were trained with.
EMBED_DIM, HIDDEN_SIZE = 100, 128

# Leading dimension of the placeholder target tensor built in main().
BATCH_SIZE = 1

''' Setup '''

en_lang_path = 'en_vocab_obj.pkl'
fr_lang_path = 'fr_vocab_obj.pkl'

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# These vocabulary objects must only ever be loaded from files that ship with
# the application, never from user-supplied or downloaded data.
with open(en_lang_path, 'rb') as file:
    en_lang = pickle.load(file)

with open(fr_lang_path, 'rb') as file:
    fr_lang = pickle.load(file)

# Build the network from the vocabulary sizes and the hyperparameters above so
# that the saved state dict can be loaded into it.
enc = Encoder(en_lang.n_words, EMBED_DIM, HIDDEN_SIZE, HIDDEN_SIZE)
attn = Attention(HIDDEN_SIZE, HIDDEN_SIZE)
dec = Decoder(EMBED_DIM, fr_lang.n_words, HIDDEN_SIZE, HIDDEN_SIZE, attn)
net = Seq2Seq(enc, dec)

weights_path = 'net_weights_1.pth'
net.load_state_dict(torch.load(weights_path, map_location=DEVICE))

# This script only ever runs inference: switch to evaluation mode so that
# train-only behaviour (e.g. dropout, if the model uses any) is disabled and
# the same input always yields the same translation.
net.eval()

# One or more non-word characters, whitespace characters or digits.  Every such
# run is treated as a separator between tokens.
regex_pattern = r"[\W\s\d]+"
_separator_re = re.compile(regex_pattern)

# Flat list of (full form, contraction) pairs: each even position holds the
# expansion of the contraction stored immediately after it.  Contractions are
# written the way they look once the apostrophe has been replaced by a space
# ("i'm" -> "i m").  The trailing space in "i am " is harmless because
# expansions are re-split on whitespace before use.
eng_prefixes = [
    "i will", "i ll",
    "i am ", "i m",
    "i have", "i ve",
    "he is", "he s",
    "she is", "she s",
    "you are", "you re",
    "we are", "we re",
    "they are", "they re",
    "i did", "i d",
]

# contraction -> full form, derived from the pair layout above.  Only the
# contractions are keys, so a full form is never "expanded" into a neighbour.
_CONTRACTIONS = dict(zip(eng_prefixes[1::2], eng_prefixes[0::2]))


def preprocess(sentence: str) -> list:
    """Turn a raw sentence into a list of lowercase word tokens.

    The sentence is lowercased and split on single spaces.  Within each chunk
    every run of punctuation, whitespace or digits becomes a separator.  A
    chunk that is a known contraction (e.g. ``"i'm"``) is replaced by its full
    form (``"i am"``).  The result is then split again so that no returned
    token ever contains a space.

    Args:
        sentence: Free-form input text.

    Returns:
        The tokens in order of appearance; an empty list when the input holds
        no letters at all.
    """
    sequence = []
    for chunk in sentence.lower().split(' '):
        cleaned = _separator_re.sub(' ', chunk).strip()
        # Expand contractions; everything else passes through unchanged.
        expanded = _CONTRACTIONS.get(cleaned, cleaned)
        # str.split() drops empty pieces, which also covers chunks that were
        # nothing but punctuation or digits.
        sequence.extend(expanded.split())

    return sequence

def main(sentence: str):
    """Translate one sentence with the loaded network and return the text.

    The sentence is tokenised with ``preprocess``, mapped to English
    vocabulary indices, run through ``net``, and the highest-scoring French
    vocabulary entry at each output position is joined into a string.

    Args:
        sentence: Text entered by the user.

    Returns:
        The predicted words separated by single spaces, cut off before the
        first ``'EOS'`` and with the first ``'SOS'`` removed.
    """
    # Words missing from the English vocabulary fall back to index 2.
    # Indices 0 and 1 frame the sequence -- presumably SOS and EOS; confirm
    # against utils.Lang.
    token_ids = [0]
    for token in preprocess(sentence):
        token_ids.append(en_lang.word_to_index.get(token, 2))
    token_ids.append(1)

    # Add a leading dimension of size one in front of the token axis.
    source = torch.tensor(token_ids).unsqueeze(0)

    # All-zero placeholder target with 10 positions; presumably this only
    # fixes how many steps get decoded -- verify against model.Seq2Seq.
    target = torch.zeros(BATCH_SIZE, 10, fr_lang.n_words)

    with torch.inference_mode():
        # Third argument is 0 -- presumably the teacher-forcing ratio; confirm
        # against model.Seq2Seq.
        scores = net(source, target, 0)
        probabilities = torch.softmax(scores, dim=2)
        # NOTE(review): softmax is monotonic, so argmax on the raw scores
        # would select the same entries.
        best_ids = torch.argmax(probabilities, dim=2).squeeze(0).tolist()

    words = []
    for word_id in best_ids:
        words.append(fr_lang.index_to_word[word_id])

    # Keep only what was produced before the first end-of-sequence marker.
    if 'EOS' in words:
        words = words[:words.index('EOS')]

    # Drop the first start-of-sequence marker, if one is left.
    if 'SOS' in words:
        words.remove('SOS')

    return ' '.join(words)

# Minimal Gradio front end: one text box for the input sentence, one text box
# for the result, both wired to main().
_source_box = gr.Textbox()
_translation_box = gr.Textbox()
app = gr.Interface(fn=main, inputs=_source_box, outputs=_translation_box)

# Start serving the interface as soon as the module is executed.
app.launch()