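"""Gradio demo: English-to-Hindi translation with a sequence-to-sequence
Transformer, using tokenizer and model checkpoints pulled from the
Hugging Face Hub."""
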
import math

import gradio as gr
import lightning as L
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

class Translator:
    def __init__(
        self,
        src_tokenizer_ckpt_path,
        tgt_tokenizer_ckpt_path,
        model_ckpt_path,
    ):
        self.src_tokenizer = Tokenizer.from_file(src_tokenizer_ckpt_path)
        self.tgt_tokenizer = Tokenizer.from_file(tgt_tokenizer_ckpt_path)
        # Disable BPE dropout so tokenization is deterministic at inference time.
        self.src_tokenizer.model.dropout = 0
        self.tgt_tokenizer.model.dropout = 0
        self.model = TransformerSeq2Seq.load_from_checkpoint(
            model_ckpt_path,
            map_location="cpu",
        )
        self.model.eval()

    def predict(self, src):
        tokenized_text = self.src_tokenizer.encode(src)
        # Shape (src_len, 1): sequence-first layout with a batch of one.
        src = torch.LongTensor(tokenized_text.ids).view(-1, 1)
        with torch.no_grad():  # no gradients needed at inference time
            tgt = self.model.greedy_decode(src, max_len=100)
        tgt = tgt.squeeze(1).tolist()
        tgt_text = self.tgt_tokenizer.decode(tgt)
        return tgt_text


def generate_square_subsequent_mask(sz):
    # Causal mask: True on and below the diagonal, then converted to the
    # additive form attention expects (0.0 = attend, -inf = blocked).
    mask = (torch.triu(torch.ones((sz, sz))) == 1).transpose(0, 1)
    mask = (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
    )
    return mask
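
# Illustrative example: for sz=3 the mask is
#   tensor([[0., -inf, -inf],
#           [0.,   0., -inf],
#           [0.,   0.,   0.]])
# so after softmax each position attends only to itself and earlier positions.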


class PositionalEncoding(nn.Module):
    def __init__(self, embedding_dim, dropout, maxlen=5000):
        super().__init__()
        # Sinusoidal encodings: PE(pos, 2i) = sin(pos / 10000^(2i/d)),
        # PE(pos, 2i + 1) = cos(pos / 10000^(2i/d)).
        den = torch.exp(
            -torch.arange(0, embedding_dim, 2) * math.log(10000) / embedding_dim
        )
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, embedding_dim))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("pos_embedding", pos_embedding)

    def forward(self, token_embedding):
        return self.dropout(
            token_embedding + self.pos_embedding[: token_embedding.size(0), :]
        )
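
# Shape note: the positional-encoding buffer is (maxlen, 1, embedding_dim), so
# it broadcasts over the batch dimension of sequence-first inputs of shape
# (seq_len, batch, embedding_dim), the layout nn.Transformer uses by default.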


class TransformerSeq2Seq(L.LightningModule):
    def __init__(
        self,
        src_vocab_size,
        tgt_vocab_size,
        embedding_dim=512,
        hidden_dim=512,
        dropout=0.1,
        nhead=8,
        num_layers=3,
        batch_size=32,
        lr=1e-4,
        weight_decay=1e-4,
        sos_idx=1,
        eos_idx=2,
        padding_idx=3,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.src_embedding = nn.Embedding(
            src_vocab_size,
            embedding_dim,
            padding_idx=padding_idx,
        )
        self.tgt_embedding = nn.Embedding(
            tgt_vocab_size,
            embedding_dim,
            padding_idx=padding_idx,
        )
        self.positional_encoding = PositionalEncoding(
            embedding_dim=embedding_dim,
            dropout=dropout,
        )
        self.transformer = nn.Transformer(
            d_model=embedding_dim,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=hidden_dim,
            dropout=dropout,
        )
        self.fc = nn.Linear(embedding_dim, tgt_vocab_size)
        # Xavier initialization for all weight matrices.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # ignore_index keeps padded target positions out of the loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=padding_idx)

    def forward(
        self,
        src,
        tgt,
        src_mask,
        tgt_mask,
        src_padding_mask,
        tgt_padding_mask,
    ):
        # Scale embeddings by sqrt(d_model), as in "Attention Is All You Need".
        src = self.src_embedding(src) * (self.hparams.embedding_dim**0.5)
        tgt = self.tgt_embedding(tgt) * (self.hparams.embedding_dim**0.5)
        src = self.positional_encoding(src)
        tgt = self.positional_encoding(tgt)
        out = self.transformer(
            src,
            tgt,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
        )
        out = self.fc(out)
        return out

    def greedy_decode(self, src, max_len):
        # Encode the source once, then decode autoregressively.
        src = self.src_embedding(src) * (self.hparams.embedding_dim**0.5)
        src = self.positional_encoding(src)
        memory = self.transformer.encoder(src)
        # Start from the <sos> token and grow the output one token at a time.
        ys = torch.ones(1, 1).fill_(self.hparams.sos_idx).type(torch.long)
        for _ in range(max_len - 1):
            tgt = self.tgt_embedding(ys) * (self.hparams.embedding_dim**0.5)
            tgt = self.positional_encoding(tgt)
            tgt_mask = generate_square_subsequent_mask(ys.size(0)).type(torch.bool)
            out = self.transformer.decoder(
                tgt,
                memory,
                tgt_mask=tgt_mask,
            )
            out = self.fc(out)
            # Keep only the logits for the last generated position.
            out = out.transpose(0, 1)[:, -1]
            prob = out.softmax(dim=-1)
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()
            ys = torch.cat(
                [ys, torch.ones(1, 1).fill_(next_word).type(torch.long)],
                dim=0,
            )
            if next_word == self.hparams.eos_idx:
                break
        return ys
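
    # Shape note: greedy_decode assumes a single example in sequence-first
    # layout, i.e. `src` of shape (src_len, 1); the loop is not batched.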

    def training_step(self, batch, batch_idx):
        src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = batch
        # Teacher forcing: feed the target shifted right, predict it shifted left.
        tgt_input = tgt[:-1, :]
        logits = self(
            src,
            tgt_input,
            src_mask,
            tgt_mask,
            src_padding_mask,
            tgt_padding_mask,
        )
        tgt_out = tgt[1:, :]
        loss = self.criterion(
            logits.reshape(-1, logits.shape[-1]),
            tgt_out.reshape(-1),
        )
        self.log("train_loss", loss, batch_size=self.hparams.batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = batch
        tgt_input = tgt[:-1, :]
        logits = self(
            src,
            tgt_input,
            src_mask,
            tgt_mask,
            src_padding_mask,
            tgt_padding_mask,
        )
        tgt_out = tgt[1:, :]
        loss = self.criterion(
            logits.reshape(-1, logits.shape[-1]),
            tgt_out.reshape(-1),
        )
        self.log("val_loss", loss, batch_size=self.hparams.batch_size)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": torch.optim.lr_scheduler.OneCycleLR(
                    optimizer=optimizer,
                    max_lr=self.hparams.lr,
                    total_steps=self.trainer.estimated_stepping_batches,
                ),
                # Step the one-cycle schedule every batch, not every epoch.
                "interval": "step",
            },
        }


# Download tokenizer and model checkpoints from the Hugging Face Hub.
src_tokenizer_ckpt_path = hf_hub_download(
    repo_id="SatwikKambham/opus100-en-hi-transformer",
    filename="tokenizer-en.json",
)
tgt_tokenizer_ckpt_path = hf_hub_download(
    repo_id="SatwikKambham/opus100-en-hi-transformer",
    filename="tokenizer-hi.json",
)
model_ckpt_path = hf_hub_download(
    repo_id="SatwikKambham/opus100-en-hi-transformer",
    filename="transformer.ckpt",
)
translator = Translator(
    src_tokenizer_ckpt_path,
    tgt_tokenizer_ckpt_path,
    model_ckpt_path,
)
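
# Quick local sanity check (illustrative; not part of the original app):
#   print(translator.predict("Hi how are you?"))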

interface = gr.Interface(
    fn=translator.predict,
    inputs=gr.components.Textbox(
        label="Source Language (English)",
        placeholder="Enter text here...",
    ),
    outputs=gr.components.Textbox(
        label="Target Language (Hindi)",
        placeholder="Translation",
    ),
    examples=[
        ["Hi how are you?"],
        ["Today is a very important day."],
        ["I like playing the guitar."],
    ],
)
interface.launch()