---
license: mit
language:
- en
- vi
pipeline_tag: translation
---

## Dataset

The model is trained on a high-quality dataset for English-Vietnamese translation:

[![GitHub Repo](https://img.shields.io/badge/GitHub-stefan--it%2Fnmt--en--vi-blue?style=flat&logo=github)](https://github.com/stefan-it/nmt-en-vi)

## Usage

```python
import tensorflow as tf

from translator import Translator
from utils import tokenizer_utils
from utils.preprocessing import input_processing, output_processing
from models.transformer import Transformer
from models.encoder import Encoder
from models.decoder import Decoder
from models.layers import EncoderLayer, DecoderLayer, MultiHeadAttention, point_wise_feed_forward_network
from models.utils import masked_loss, masked_accuracy


def main(sentence, model):
    """Translate an English sentence to Vietnamese with *model* and print both texts.

    Args:
        sentence: Raw English input text.
        model: A loaded Keras Transformer translation model.
    """
    # Load the English and Vietnamese tokenizers.
    en_tokenizer, vi_tokenizer = tokenizer_utils.load_tokenizers()

    # Build the translator from the *model* parameter (previously this
    # referenced the global `loaded_model`, leaving the parameter unused).
    translator = Translator(en_tokenizer, vi_tokenizer, model)

    # Pre-process, translate, then post-process the output tokens.
    processed_sentence = input_processing(sentence)
    translated_text = translator(processed_sentence)
    translated_text = output_processing(translated_text)

    print("Input:", processed_sentence)
    print("Translated:", translated_text)


if __name__ == "__main__":
    # Example sentence
    sentence = """ For at least six centuries, residents along a lake in the mountains of central Japan have marked the depth of winter by celebrating the return of a natural phenomenon once revered as the trail of a wandering god. 
"""

    # Custom objects required to deserialize the saved Transformer model.
    custom_objects = {
        'Transformer': Transformer,
        'Encoder': Encoder,
        'Decoder': Decoder,
        'EncoderLayer': EncoderLayer,
        'DecoderLayer': DecoderLayer,
        'MultiHeadAttention': MultiHeadAttention,
        'point_wise_feed_forward_network': point_wise_feed_forward_network,
        'masked_loss': masked_loss,
        'masked_accuracy': masked_accuracy
    }

    # Load the model from the checkpoint.
    loaded_model = tf.keras.models.load_model('ckpts/en_vi_translation.keras', custom_objects=custom_objects)

    main(sentence=sentence, model=loaded_model)
```