# engtokantranslation / test_model.py
# Source: yashwan2003's Hugging Face repository (folder upload, revision e2dd3a8, verified)
"""
Test script for English to Kannada Machine Translation Model
This script loads the trained model and performs inference testing
"""
import torch
import json
from tokenizers import Tokenizer
from main import Transformer, greedy_decode
# Special tokens
START_TOKEN = '<START>'
END_TOKEN = '<END>'
PADDING_TOKEN = '<PAD>'
UNKNOWN_TOKEN = '<UNK>'
def load_model_and_tokenizers(model_path='best_model.pt',
                              src_tokenizer_path='source_tokenizer.json',
                              tgt_tokenizer_path='target_tokenizer.json'):
    """
    Load the trained model and tokenizers.

    Args:
        model_path: Path to saved model checkpoint
        src_tokenizer_path: Path to source (English) tokenizer
        tgt_tokenizer_path: Path to target (Kannada) tokenizer

    Returns:
        model, src_tokenizer, tgt_tokenizer, vocab_info, device
    """
    # Select GPU when available, otherwise fall back to CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load checkpoint.
    # NOTE(security): torch.load unpickles arbitrary Python objects — only
    # load checkpoints from trusted sources.
    print(f"Loading model from {model_path}...")
    checkpoint = torch.load(model_path, map_location=device)
    vocab_info = checkpoint['vocab_info']

    # Initialize model. Hyperparameters below must match those used at
    # training time (they are not stored in the checkpoint).
    model = Transformer(
        d_model=384,
        ffn_hidden=1536,
        num_heads=6,
        drop_prob=0.1,
        num_layers=4,
        max_sequence_length=75,
        src_vocab_size=vocab_info['source_vocab_size'],
        tgt_vocab_size=vocab_info['target_vocab_size']
    )

    # Load model weights and switch to inference mode.
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    print("Model loaded successfully!")
    # Losses are informational only; tolerate older checkpoints that
    # did not record them instead of crashing with a KeyError.
    train_loss = checkpoint.get('train_loss')
    val_loss = checkpoint.get('val_loss')
    if train_loss is not None:
        print(f"Training loss: {train_loss:.4f}")
    if val_loss is not None:
        print(f"Validation loss: {val_loss:.4f}")

    # Load tokenizers.
    print("Loading tokenizers...")
    src_tokenizer = Tokenizer.from_file(src_tokenizer_path)
    tgt_tokenizer = Tokenizer.from_file(tgt_tokenizer_path)

    return model, src_tokenizer, tgt_tokenizer, vocab_info, device
def translate_sentence(sentence, model, src_tokenizer, tgt_tokenizer, vocab_info, device):
    """
    Translate one English sentence into Kannada.

    Thin wrapper around greedy_decode with the decoding length capped at 75
    tokens (matching the model's max_sequence_length).

    Args:
        sentence: English sentence to translate
        model: Trained transformer model
        src_tokenizer: Source tokenizer
        tgt_tokenizer: Target tokenizer
        vocab_info: Vocabulary information
        device: Computation device

    Returns:
        Translated Kannada sentence
    """
    return greedy_decode(
        model=model,
        src_sentence=sentence,
        source_tokenizer=src_tokenizer,
        target_tokenizer=tgt_tokenizer,
        vocab_info=vocab_info,
        device=device,
        max_length=75,
    )
def batch_translate(sentences, model, src_tokenizer, tgt_tokenizer, vocab_info, device):
    """
    Translate a sequence of English sentences one at a time.

    Args:
        sentences: List of English sentences
        model: Trained transformer model
        src_tokenizer: Source tokenizer
        tgt_tokenizer: Target tokenizer
        vocab_info: Vocabulary information
        device: Computation device

    Returns:
        List of translated Kannada sentences, in input order
    """
    # Each sentence is decoded independently; order is preserved.
    return [
        translate_sentence(text, model, src_tokenizer,
                           tgt_tokenizer, vocab_info, device)
        for text in sentences
    ]
def run_test_suite():
    """
    Translate a fixed set of English sentences and print each result.

    Loads the model/tokenizers with default paths, then prints an
    English/Kannada pair for every predefined test sentence.
    """
    model, src_tokenizer, tgt_tokenizer, vocab_info, device = load_model_and_tokenizers()

    # Predefined English inputs covering greetings, questions and statements.
    test_sentences = [
        "Hello, how are you?",
        "What is your name?",
        "I am going to school.",
        "The weather is nice today.",
        "Can you help me?",
        "Thank you very much.",
        "Good morning!",
        "Where is the hospital?",
        "I love learning new languages.",
        "This is a beautiful place."
    ]

    separator = "=" * 80
    print("\n" + separator)
    print("TRANSLATION TEST RESULTS")
    print(separator + "\n")

    for idx, english in enumerate(test_sentences, 1):
        kannada = translate_sentence(english, model, src_tokenizer,
                                     tgt_tokenizer, vocab_info, device)
        print(f"Test {idx}:")
        print(f"  English: {english}")
        print(f"  Kannada: {kannada}")
        print()

    print(separator)
def interactive_mode():
    """
    Read English sentences from stdin and print Kannada translations.

    Loops until the user types 'quit' or 'exit' (case-insensitive);
    blank input is silently skipped.
    """
    model, src_tokenizer, tgt_tokenizer, vocab_info, device = load_model_and_tokenizers()

    banner = "=" * 80
    print("\n" + banner)
    print("INTERACTIVE TRANSLATION MODE")
    print("English to Kannada Translation")
    print("Type 'quit' or 'exit' to stop")
    print(banner + "\n")

    while True:
        text = input("Enter English sentence: ").strip()
        if text.lower() in ('quit', 'exit'):
            print("Goodbye!")
            break
        # Ignore empty input and re-prompt.
        if text:
            kannada = translate_sentence(text, model, src_tokenizer,
                                         tgt_tokenizer, vocab_info, device)
            print(f"Kannada: {kannada}\n")
def benchmark_performance(num_samples=100):
    """
    Benchmark translation throughput and latency on a fixed sentence.

    Runs 5 warmup translations (to absorb one-time costs such as CUDA
    kernel compilation), then times num_samples translations.

    Args:
        num_samples: Number of timed translations to run
    """
    import time

    model, src_tokenizer, tgt_tokenizer, vocab_info, device = load_model_and_tokenizers()
    test_sentence = "This is a test sentence for benchmarking."

    print("\n" + "="*80)
    print("PERFORMANCE BENCHMARK")
    print("="*80 + "\n")

    # Warmup: first runs include one-off setup costs we don't want timed.
    for _ in range(5):
        _ = translate_sentence(test_sentence, model, src_tokenizer,
                               tgt_tokenizer, vocab_info, device)

    # Benchmark. perf_counter() is the monotonic, high-resolution clock
    # intended for interval timing (time.time() can jump with wall-clock
    # adjustments and has coarser resolution on some platforms).
    start_time = time.perf_counter()
    for _ in range(num_samples):
        _ = translate_sentence(test_sentence, model, src_tokenizer,
                               tgt_tokenizer, vocab_info, device)
    end_time = time.perf_counter()

    total_time = end_time - start_time
    avg_time = total_time / num_samples

    print(f"Total translations: {num_samples}")
    print(f"Total time: {total_time:.2f} seconds")
    print(f"Average time per translation: {avg_time*1000:.2f} ms")
    print(f"Translations per second: {num_samples/total_time:.2f}")
    print(f"Device: {device}")
    print("="*80)
if __name__ == "__main__":
    import sys

    # Map CLI mode names to their entry points.
    modes = {
        'test': run_test_suite,
        'interactive': interactive_mode,
        'benchmark': benchmark_performance,
    }

    if len(sys.argv) > 1:
        handler = modes.get(sys.argv[1].lower())
        if handler is not None:
            handler()
        else:
            print("Usage: python test_model.py [test|interactive|benchmark]")
            print(" test - Run predefined test suite")
            print(" interactive - Interactive translation mode")
            print(" benchmark - Performance benchmarking")
    else:
        # No argument: default to the predefined test suite.
        run_test_suite()