| """ |
| Test script for English to Kannada Machine Translation Model |
| This script loads the trained model and performs inference testing |
| """ |
|
|
| import torch |
| import json |
| from tokenizers import Tokenizer |
| from main import Transformer, greedy_decode |
|
|
| |
# Special-token string constants.
# NOTE(review): not referenced anywhere in this file — presumably kept for
# parity with the training code in main.py (which owns tokenization); confirm
# these match the special tokens the tokenizers were actually trained with.
START_TOKEN = '<START>'
END_TOKEN = '<END>'
PADDING_TOKEN = '<PAD>'
UNKNOWN_TOKEN = '<UNK>'
|
|
|
|
def load_model_and_tokenizers(model_path='best_model.pt',
                              src_tokenizer_path='source_tokenizer.json',
                              tgt_tokenizer_path='target_tokenizer.json'):
    """Load the trained translation model and both tokenizers from disk.

    Args:
        model_path: Path to the saved model checkpoint.
        src_tokenizer_path: Path to the source (English) tokenizer file.
        tgt_tokenizer_path: Path to the target (Kannada) tokenizer file.

    Returns:
        Tuple of (model, src_tokenizer, tgt_tokenizer, vocab_info, device).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    print(f"Loading model from {model_path}...")
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    checkpoint = torch.load(model_path, map_location=device)
    vocab_info = checkpoint['vocab_info']

    # NOTE(review): these hyperparameters must mirror the training
    # configuration in main.py — confirm if the training setup changes.
    architecture = dict(
        d_model=384,
        ffn_hidden=1536,
        num_heads=6,
        drop_prob=0.1,
        num_layers=4,
        max_sequence_length=75,
        src_vocab_size=vocab_info['source_vocab_size'],
        tgt_vocab_size=vocab_info['target_vocab_size'],
    )
    model = Transformer(**architecture)

    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()  # inference only: disables dropout

    print(f"Model loaded successfully!")
    print(f"Training loss: {checkpoint['train_loss']:.4f}")
    print(f"Validation loss: {checkpoint['val_loss']:.4f}")

    print(f"Loading tokenizers...")
    src_tokenizer = Tokenizer.from_file(src_tokenizer_path)
    tgt_tokenizer = Tokenizer.from_file(tgt_tokenizer_path)

    return model, src_tokenizer, tgt_tokenizer, vocab_info, device
|
|
|
|
def translate_sentence(sentence, model, src_tokenizer, tgt_tokenizer, vocab_info, device):
    """Translate a single English sentence to Kannada using greedy decoding.

    Args:
        sentence: English sentence to translate.
        model: Trained transformer model.
        src_tokenizer: Source (English) tokenizer.
        tgt_tokenizer: Target (Kannada) tokenizer.
        vocab_info: Vocabulary information from the checkpoint.
        device: Computation device.

    Returns:
        The translated Kannada sentence.
    """
    # Thin wrapper: delegates straight to greedy_decode with the same
    # max_length (75) the model was configured with.
    return greedy_decode(
        model=model,
        src_sentence=sentence,
        source_tokenizer=src_tokenizer,
        target_tokenizer=tgt_tokenizer,
        vocab_info=vocab_info,
        device=device,
        max_length=75,
    )
|
|
|
|
def batch_translate(sentences, model, src_tokenizer, tgt_tokenizer, vocab_info, device):
    """Translate a sequence of English sentences to Kannada.

    Args:
        sentences: Iterable of English sentences.
        model: Trained transformer model.
        src_tokenizer: Source (English) tokenizer.
        tgt_tokenizer: Target (Kannada) tokenizer.
        vocab_info: Vocabulary information from the checkpoint.
        device: Computation device.

    Returns:
        List of translated Kannada sentences, in input order.
    """
    # Sentences are decoded one at a time; greedy decoding here is
    # per-sentence, not batched on-device.
    return [
        translate_sentence(text, model, src_tokenizer,
                           tgt_tokenizer, vocab_info, device)
        for text in sentences
    ]
|
|
|
|
def run_test_suite():
    """Run a predefined suite of English sentences through the model
    and print each translation.
    """
    model, src_tokenizer, tgt_tokenizer, vocab_info, device = load_model_and_tokenizers()

    # A spread of sentence types: questions, statements, greetings.
    test_sentences = [
        "Hello, how are you?",
        "What is your name?",
        "I am going to school.",
        "The weather is nice today.",
        "Can you help me?",
        "Thank you very much.",
        "Good morning!",
        "Where is the hospital?",
        "I love learning new languages.",
        "This is a beautiful place.",
    ]

    print("\n" + "="*80)
    print("TRANSLATION TEST RESULTS")
    print("="*80 + "\n")

    case_number = 1
    for sentence in test_sentences:
        translation = translate_sentence(sentence, model, src_tokenizer,
                                         tgt_tokenizer, vocab_info, device)
        print(f"Test {case_number}:")
        print(f" English: {sentence}")
        print(f" Kannada: {translation}")
        print()
        case_number += 1

    print("="*80)
|
|
|
|
def interactive_mode():
    """Interactive REPL: read English sentences from stdin and print their
    Kannada translations until the user types 'quit' or 'exit'.
    """
    model, src_tokenizer, tgt_tokenizer, vocab_info, device = load_model_and_tokenizers()

    print("\n" + "="*80)
    print("INTERACTIVE TRANSLATION MODE")
    print("English to Kannada Translation")
    print("Type 'quit' or 'exit' to stop")
    print("="*80 + "\n")

    while True:
        sentence = input("Enter English sentence: ").strip()

        # Exit keywords are matched case-insensitively.
        if sentence.lower() in ('quit', 'exit'):
            print("Goodbye!")
            break

        # Blank input: just prompt again.
        if sentence:
            translation = translate_sentence(sentence, model, src_tokenizer,
                                             tgt_tokenizer, vocab_info, device)
            print(f"Kannada: {translation}\n")
|
|
|
|
def benchmark_performance(num_samples=100):
    """Benchmark translation latency and throughput on a fixed sentence.

    Args:
        num_samples: Number of timed translations to run; must be positive.

    Raises:
        ValueError: If num_samples is not a positive integer.
    """
    import time

    # Fail early with a clear error instead of a ZeroDivisionError after
    # the (slow) model load and warm-up.
    if num_samples <= 0:
        raise ValueError("num_samples must be a positive integer")

    model, src_tokenizer, tgt_tokenizer, vocab_info, device = load_model_and_tokenizers()

    test_sentence = "This is a test sentence for benchmarking."

    print("\n" + "="*80)
    print("PERFORMANCE BENCHMARK")
    print("="*80 + "\n")

    # Warm-up runs so one-time costs (CUDA kernel launches, caches, lazy
    # initialization) don't skew the timed loop.
    for _ in range(5):
        _ = translate_sentence(test_sentence, model, src_tokenizer,
                               tgt_tokenizer, vocab_info, device)

    # Fix: time.perf_counter() instead of time.time() — perf_counter is the
    # monotonic, high-resolution clock intended for interval timing, while
    # time.time() is wall-clock and can jump (NTP adjustment, DST).
    start_time = time.perf_counter()
    for _ in range(num_samples):
        _ = translate_sentence(test_sentence, model, src_tokenizer,
                               tgt_tokenizer, vocab_info, device)
    end_time = time.perf_counter()

    total_time = end_time - start_time
    avg_time = total_time / num_samples

    print(f"Total translations: {num_samples}")
    print(f"Total time: {total_time:.2f} seconds")
    print(f"Average time per translation: {avg_time*1000:.2f} ms")
    print(f"Translations per second: {num_samples/total_time:.2f}")
    print(f"Device: {device}")
    print("="*80)
|
|
|
|
if __name__ == "__main__":
    import sys

    # Dispatch table: CLI mode name -> runner function.
    runners = {
        'test': run_test_suite,
        'interactive': interactive_mode,
        'benchmark': benchmark_performance,
    }

    if len(sys.argv) > 1:
        mode = sys.argv[1].lower()
        runner = runners.get(mode)
        if runner is not None:
            runner()
        else:
            print("Usage: python test_model.py [test|interactive|benchmark]")
            print(" test - Run predefined test suite")
            print(" interactive - Interactive translation mode")
            print(" benchmark - Performance benchmarking")
    else:
        # No argument: default to the predefined test suite.
        run_test_suite()
|
|