Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import argparse | |
| import streamlit as st | |
| import sentencepiece as spm | |
| from utils import utils_cls | |
| from model import BanglaTransformer | |
| from config import config as cfg | |
# Fix the RNG so any stochastic ops behave identically across runs.
torch.manual_seed(0)
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Inference is deliberately pinned to CPU; the CUDA line above is kept
# disabled for reference.
device = torch.device('cpu')
# Shared utility helper (vocab loading, subsequent-mask generation) bound
# to the inference device.
uobj = utils_cls(device=device)
# Module metadata.
# NOTE(review): "__MODIFICAIOTN__" is misspelled ("MODIFICATION"); left
# unchanged since renaming a module-level attribute could break lookups.
__MODULE__ = "Bangla Language Translation"
__MAIL__ = "saifulbrur79@gmail.com"
__MODIFICAIOTN__ = "28/03/2023"
__LICENSE__ = "MIT"
# Page header rendered by Streamlit.
st.write(""" Bangla to English Translation """)
# Directory holding tokenizer models, vocab pickles and model weights.
BASE_URL = "./model"
class Bn2EnTranslation:
    """Bangla-to-English translation pipeline.

    Loads SentencePiece tokenizers, source/target vocabularies and a
    Transformer checkpoint from ``BASE_URL``, then greedy-decodes input
    sentences.  All artefact paths are fixed in ``__init__``.
    """

    def __init__(self):
        # Paths to pretrained artefacts under BASE_URL.
        self.bn_tokenizer = os.path.join(BASE_URL, "bn_model.model")
        self.en_tokenizer = os.path.join(BASE_URL, "en_model.model")
        self.bn_vocab = os.path.join(BASE_URL, "bn_vocab.pkl")
        self.en_vocab = os.path.join(BASE_URL, "en_vocab.pkl")
        self.model = os.path.join(BASE_URL, "pytorch_model.pt")

    def read_data(self, data_path):
        """Read a TAB-separated parallel corpus.

        Each line is ``"<source>\\t<target>\\n"``; returns a list of
        ``[source, target]`` pairs.

        Fix: open with explicit UTF-8 encoding so Bangla text decodes
        correctly regardless of the platform's default encoding.
        """
        with open(data_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        return [
            [line.split("\t")[0], line.split("\t")[1].replace("\n", "")]
            for line in lines
        ]

    def load_tokenizer(self, tokenizer_path: str = "") -> object:
        """Load a SentencePiece tokenizer from ``tokenizer_path``."""
        return spm.SentencePieceProcessor(model_file=tokenizer_path)

    def get_vocab(self, BN_VOCAL_PATH: str = "", EN_VOCAL_PATH: str = ""):
        """Load and return the (bangla_vocab, english_vocab) pair."""
        bn_vocal = uobj.load_bn_vocal(BN_VOCAL_PATH)
        en_vocal = uobj.load_en_vocal(EN_VOCAL_PATH)
        return bn_vocal, en_vocal

    def load_model(self, model_path: str = "", SRC_VOCAB_SIZE: int = 0, TGT_VOCAB_SIZE: int = 0):
        """Build the BanglaTransformer, load checkpoint weights, set eval mode.

        Fix: ``map_location=device`` so a checkpoint saved on GPU can be
        loaded on this CPU-only device (the module pins ``device`` to CPU).
        """
        model = BanglaTransformer(
            cfg.NUM_ENCODER_LAYERS, cfg.NUM_DECODER_LAYERS, cfg.EMB_SIZE, SRC_VOCAB_SIZE,
            TGT_VOCAB_SIZE, cfg.FFN_HID_DIM, nhead=cfg.NHEAD)
        model.to(device)
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        return model

    def greedy_decode(self, model, src, src_mask, max_len, start_symbol, eos_index):
        """Greedy-decode a target sequence from ``src``.

        Starts from ``start_symbol`` and appends the arg-max token each step
        until ``eos_index`` is produced or ``max_len`` tokens are emitted.
        Returns the decoded token-id tensor ``ys`` (shape: steps x 1).

        Fixes: removed an unused ``memory_mask`` computed every iteration,
        and hoisted the loop-invariant ``memory.to(device)`` out of the loop.
        """
        src = src.to(device)
        src_mask = src_mask.to(device)
        memory = model.encode(src, src_mask).to(device)
        ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
        for _ in range(max_len - 1):
            tgt_mask = (uobj.generate_square_subsequent_mask(ys.size(0))
                        .type(torch.bool)).to(device)
            out = model.decode(ys, memory, tgt_mask)
            out = out.transpose(0, 1)
            # Project only the last position to vocabulary logits.
            prob = model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()
            ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
            if next_word == eos_index:
                break
        return ys

    def get_bntoen_model(self):
        """Load tokenizer, vocabularies and model; return them in a dict."""
        print("Tokenizer Loading ...... : ", end="", flush=True)
        bn_tokenizer = self.load_tokenizer(tokenizer_path=self.bn_tokenizer)
        print("Done")
        print("Vocab Loading ...... : ", end="", flush=True)
        bn_vocab, en_vocab = self.get_vocab(BN_VOCAL_PATH=self.bn_vocab, EN_VOCAL_PATH=self.en_vocab)
        print("Done")
        print("Model Loading ...... : ", end="", flush=True)
        model = self.load_model(model_path=self.model, SRC_VOCAB_SIZE=len(bn_vocab), TGT_VOCAB_SIZE=len(en_vocab))
        print("Done")
        return {
            "bn_tokenizer": bn_tokenizer,
            "bn_vocab": bn_vocab,
            "en_vocab": en_vocab,
            "model": model,
        }

    def translate(self, text, models):
        """Translate Bangla ``text`` to English using the loaded ``models``.

        Fixes: hoisted the per-token ``get_stoi()`` call out of the loop,
        dropped the unused PAD index, and removed a no-op identity ``map``.
        """
        model = models["model"]
        src_vocab = models["bn_vocab"]
        tgt_vocab = models["en_vocab"]
        src_tokenizer = models["bn_tokenizer"]
        src = text
        BOS_IDX, EOS_IDX = src_vocab['<bos>'], src_vocab['<eos>']
        # Hoist the string->index table once instead of rebuilding per token.
        stoi = src_vocab.get_stoi()
        tokens = [BOS_IDX] + [stoi[tok] for tok in src_tokenizer.encode(src, out_type=str)] + [EOS_IDX]
        num_tokens = len(tokens)
        src = torch.LongTensor(tokens).reshape(num_tokens, 1)
        src_mask = torch.zeros(num_tokens, num_tokens).type(torch.bool)
        tgt_tokens = self.greedy_decode(
            model, src, src_mask, max_len=num_tokens + 5,
            start_symbol=BOS_IDX, eos_index=EOS_IDX).flatten()
        itos = tgt_vocab.get_itos()
        p_text = " ".join(itos[tok] for tok in tgt_tokens).replace("<bos>", "").replace("<eos>", "")
        # SentencePiece marks word starts with "▁"; rebuild words on it.
        pts = " ".join(p_text.replace(" ", "").split("▁"))
        return pts.strip()
| # if __name__ == "__main__": | |
| # print(torch.cuda.get_device_name(0)) | |
| text = "এই উপজেলায় ১টি সরকারি কলেজ রয়েছে" | |
| obj = Bn2EnTranslation() | |
| models = obj.get_bntoen_model() | |
| text = st.text_area("Enter some text:এই উপজেলায় ১টি সরকারি কলেজ রয়েছে") | |
| if text: | |
| pre = obj.translate(text, models) | |
| print(f"Input : {text}") | |
| print(f"Prediction : {pre}") | |