|
|
|
|
|
""" |
|
|
Benchmark script to compare performance between standard Transformers |
|
|
and CTranslate2 optimized translation models. |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import logging |
|
|
import os |
|
|
import sys |
|
|
import time |
|
|
from typing import Dict, List, Tuple |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import tqdm |
|
|
|
|
|
|
|
|
# Put the project root (the directory containing this script's directory) at
# the front of the import path so the ``app`` package resolves when the
# script is executed directly.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, _PROJECT_ROOT)

# Root logging configuration for this command-line script.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# Project model classes. Any ImportError raised while importing these modules
# (including one raised by their own imports) ends the script.
try:
    from app.models.translation_model import TranslationModel
    from app.models.translation_model_ct2 import TranslationModelCT2
except ImportError as exc:
    # Include the underlying error: the generic "run from the project root"
    # hint is misleading when the failure comes from inside the imported
    # modules rather than from the package not being found.
    logger.error(
        "Could not import translation models. Make sure you're running this "
        "script from the project root. Original error: %s",
        exc,
    )
    sys.exit(1)
|
|
|
|
|
|
|
|
# Destination city used in the "book a flight" sentence for each language
# pair; the remaining test sentences are the same for every pair.
_FLIGHT_DESTINATIONS = {
    "en-es": "Madrid",
    "en-fr": "Paris",
    "en-de": "Berlin",
    "en-dra": "Chennai",
}

# English source sentences to translate, keyed by "<src>-<tgt>" pair.
# Every pair receives its own freshly built list.
TEST_SENTENCES = {
    pair: [
        "Hello, how are you today?",
        f"I would like to book a flight to {city} for next week.",
        "The quick brown fox jumps over the lazy dog.",
        "Artificial intelligence is transforming the way we live and work.",
        "Please contact our customer service if you have any questions.",
    ]
    for pair, city in _FLIGHT_DESTINATIONS.items()
}
|
|
|
|
|
def benchmark_standard_model(
    src_lang: str,
    tgt_lang: str,
    sentences: List[str],
    num_runs: int = 5,
    warm_up: int = 2
) -> Dict:
    """Benchmark the standard Transformers model, one sentence at a time.

    A new ``TranslationModel`` is constructed, ``warm_up`` untimed passes over
    (at most) the first two sentences are made, and then every sentence is
    translated and timed individually, ``num_runs`` times over.

    Args:
        src_lang: Source language code passed to ``model.translate``.
        tgt_lang: Target language code passed to ``model.translate``.
        sentences: Sentences to translate; must not be empty.
        num_runs: Number of timed passes over ``sentences``; must be >= 1.
        warm_up: Number of untimed warm-up passes.

    Returns:
        Dict with per-sentence latency statistics in seconds (``mean_time``,
        ``median_time``, ``std_dev`` (population std), ``min_time``,
        ``max_time``, ``total_time``), the number of timed translations
        (``num_sentences``) and the outputs of the first timed run
        (``translations``).

    Raises:
        ValueError: If ``sentences`` is empty or ``num_runs`` is less than 1.
    """
    # Fail fast: without samples the statistics below would only fail after
    # the model had been loaded, with an opaque numpy zero-size-array error.
    if not sentences:
        raise ValueError("sentences must not be empty")
    if num_runs < 1:
        raise ValueError("num_runs must be at least 1")

    logger.info(f"Benchmarking standard Transformers model for {src_lang}-{tgt_lang}")

    model = TranslationModel()

    # Untimed warm-up so first-call overhead is excluded from the timed runs.
    logger.info(f"Performing {warm_up} warm-up runs...")
    for _ in range(warm_up):
        for sentence in sentences[:2]:
            model.translate(sentence, src_lang, tgt_lang)

    logger.info(f"Performing {num_runs} benchmark runs...")
    times = []
    translations = []
    for run in range(num_runs):
        run_translations = []
        for sentence in tqdm.tqdm(sentences, desc=f"Run {run+1}/{num_runs}"):
            # perf_counter() is monotonic with the highest available
            # resolution; time.time() follows the adjustable wall clock.
            start_time = time.perf_counter()
            translation = model.translate(sentence, src_lang, tgt_lang)
            times.append(time.perf_counter() - start_time)
            run_translations.append(translation)

        # Only the first timed run's outputs are reported.
        if run == 0:
            translations = run_translations

    all_times = np.asarray(times)
    return {
        "mean_time": float(np.mean(all_times)),
        "median_time": float(np.median(all_times)),
        "std_dev": float(np.std(all_times)),
        "min_time": float(np.min(all_times)),
        "max_time": float(np.max(all_times)),
        "total_time": float(np.sum(all_times)),
        "num_sentences": len(sentences) * num_runs,
        "translations": translations,
    }
|
|
|
|
|
def benchmark_ct2_model(
    src_lang: str,
    tgt_lang: str,
    sentences: List[str],
    num_runs: int = 5,
    warm_up: int = 2
) -> Dict:
    """Benchmark the CTranslate2 optimized model, one sentence at a time.

    A new ``TranslationModelCT2`` is constructed, ``warm_up`` untimed passes
    over (at most) the first two sentences are made, and then every sentence
    is translated and timed individually, ``num_runs`` times over.

    Args:
        src_lang: Source language code passed to ``model.translate``.
        tgt_lang: Target language code passed to ``model.translate``.
        sentences: Sentences to translate; must not be empty.
        num_runs: Number of timed passes over ``sentences``; must be >= 1.
        warm_up: Number of untimed warm-up passes.

    Returns:
        Dict with per-sentence latency statistics in seconds (``mean_time``,
        ``median_time``, ``std_dev`` (population std), ``min_time``,
        ``max_time``, ``total_time``), the number of timed translations
        (``num_sentences``) and the outputs of the first timed run
        (``translations``).

    Raises:
        ValueError: If ``sentences`` is empty or ``num_runs`` is less than 1.
    """
    # Fail fast: without samples the statistics below would only fail after
    # the model had been loaded, with an opaque numpy zero-size-array error.
    if not sentences:
        raise ValueError("sentences must not be empty")
    if num_runs < 1:
        raise ValueError("num_runs must be at least 1")

    logger.info(f"Benchmarking CTranslate2 model for {src_lang}-{tgt_lang}")

    model = TranslationModelCT2()

    # Untimed warm-up so first-call overhead is excluded from the timed runs.
    logger.info(f"Performing {warm_up} warm-up runs...")
    for _ in range(warm_up):
        for sentence in sentences[:2]:
            model.translate(sentence, src_lang, tgt_lang)

    logger.info(f"Performing {num_runs} benchmark runs...")
    times = []
    translations = []
    for run in range(num_runs):
        run_translations = []
        for sentence in tqdm.tqdm(sentences, desc=f"Run {run+1}/{num_runs}"):
            # perf_counter() is monotonic with the highest available
            # resolution; time.time() follows the adjustable wall clock.
            start_time = time.perf_counter()
            translation = model.translate(sentence, src_lang, tgt_lang)
            times.append(time.perf_counter() - start_time)
            run_translations.append(translation)

        # Only the first timed run's outputs are reported.
        if run == 0:
            translations = run_translations

    all_times = np.asarray(times)
    return {
        "mean_time": float(np.mean(all_times)),
        "median_time": float(np.median(all_times)),
        "std_dev": float(np.std(all_times)),
        "min_time": float(np.min(all_times)),
        "max_time": float(np.max(all_times)),
        "total_time": float(np.sum(all_times)),
        "num_sentences": len(sentences) * num_runs,
        "translations": translations,
    }
|
|
|
|
|
def benchmark_batch(
    src_lang: str,
    tgt_lang: str,
    sentences: List[str],
    num_runs: int = 5,
    warm_up: int = 2
) -> Dict:
    """Benchmark batch translation with the CTranslate2 model.

    A new ``TranslationModelCT2`` is constructed, ``warm_up`` untimed batch
    calls on (at most) the first two sentences are made, and then the whole
    ``sentences`` list is translated in a single ``translate_batch`` call,
    timed, ``num_runs`` times.

    Args:
        src_lang: Source language code passed to ``model.translate_batch``.
        tgt_lang: Target language code passed to ``model.translate_batch``.
        sentences: Sentences forming one batch; must not be empty.
        num_runs: Number of timed batch calls; must be >= 1.
        warm_up: Number of untimed warm-up batch calls.

    Returns:
        Dict with per-batch latency statistics in seconds (``mean_time``,
        ``median_time``, ``std_dev`` (population std), ``min_time``,
        ``max_time``, ``total_time``), the batch size (``num_sentences``),
        the number of timed batches (``num_batches``) and the outputs of the
        first timed batch (``translations``).

    Raises:
        ValueError: If ``sentences`` is empty or ``num_runs`` is less than 1.
    """
    # Fail fast: without samples the statistics below would only fail after
    # the model had been loaded, with an opaque numpy zero-size-array error.
    if not sentences:
        raise ValueError("sentences must not be empty")
    if num_runs < 1:
        raise ValueError("num_runs must be at least 1")

    logger.info(f"Benchmarking CTranslate2 batch translation for {src_lang}-{tgt_lang}")

    model = TranslationModelCT2()

    # Untimed warm-up so first-call overhead is excluded from the timed runs.
    logger.info(f"Performing {warm_up} warm-up runs...")
    for _ in range(warm_up):
        model.translate_batch(sentences[:2], src_lang, tgt_lang)

    logger.info(f"Performing {num_runs} benchmark runs...")
    times = []
    translations = []
    for run in range(num_runs):
        # perf_counter() is monotonic with the highest available resolution;
        # time.time() follows the adjustable wall clock.
        start_time = time.perf_counter()
        batch_translations = model.translate_batch(sentences, src_lang, tgt_lang)
        times.append(time.perf_counter() - start_time)

        # Only the first timed batch's outputs are reported.
        if run == 0:
            translations = batch_translations

    return {
        "mean_time": float(np.mean(times)),
        "median_time": float(np.median(times)),
        "std_dev": float(np.std(times)),
        "min_time": float(np.min(times)),
        "max_time": float(np.max(times)),
        "total_time": float(np.sum(times)),
        "num_sentences": len(sentences),
        "num_batches": num_runs,
        "translations": translations,
    }
|
|
|
|
|
def run_benchmarks(
    lang_pairs: List[Tuple[str, str]],
    num_runs: int = 5,
    warm_up: int = 2,
    output_file: str = "benchmark_results.json"
) -> Dict:
    """Benchmark each requested language pair and write the results as JSON.

    For every ``(src, tgt)`` pair with an entry in ``TEST_SENTENCES`` this runs
    the standard-model, CTranslate2 and CTranslate2-batch benchmarks, derives
    two speedup ratios and logs a summary. Pairs without test sentences are
    skipped with a warning. The collected results are written to
    ``output_file`` (overwriting it) once all pairs are done.

    Args:
        lang_pairs: ``(source, target)`` language code tuples.
        num_runs: Timed runs per individual benchmark.
        warm_up: Untimed warm-up runs per individual benchmark.
        output_file: Destination path for the JSON results.

    Returns:
        The results dict that was written to ``output_file``.
    """
    # The device recorded here is what torch reports; whether each model
    # actually runs on it is decided inside the model classes (not shown here).
    cuda_ok = torch.cuda.is_available()
    device = "cuda" if cuda_ok else "cpu"
    logger.info(f"Running benchmarks on {device}")

    per_pair = {}
    results = {
        "device": device,
        "cuda_available": cuda_ok,
        "cuda_version": torch.version.cuda if cuda_ok else None,
        "num_runs": num_runs,
        "warm_up_runs": warm_up,
        "language_pairs": per_pair,
    }

    for src_lang, tgt_lang in lang_pairs:
        model_key = f"{src_lang}-{tgt_lang}"
        sentences = TEST_SENTENCES.get(model_key)
        if sentences is None:
            logger.warning(f"No test sentences available for {model_key}, skipping...")
            continue

        logger.info(f"Benchmarking {model_key}...")

        bench_args = (src_lang, tgt_lang, sentences, num_runs, warm_up)
        standard_stats = benchmark_standard_model(*bench_args)
        ct2_stats = benchmark_ct2_model(*bench_args)
        batch_stats = benchmark_batch(*bench_args)

        standard_mean = standard_stats["mean_time"]
        # Ratio of mean per-sentence latencies.
        speedup = standard_mean / ct2_stats["mean_time"]
        # Estimated time for the standard model to translate all sentences
        # one by one, relative to a single CTranslate2 batch call.
        batch_speedup = standard_mean * len(sentences) / batch_stats["mean_time"]

        per_pair[model_key] = {
            "standard_model": standard_stats,
            "ct2_model": ct2_stats,
            "batch_translation": batch_stats,
            "speedup": float(speedup),
            "batch_speedup": float(batch_speedup),
        }

        logger.info(f"\nResults for {model_key}:")
        logger.info(f"  Standard model average time: {standard_mean:.4f}s")
        logger.info(f"  CTranslate2 model average time: {ct2_stats['mean_time']:.4f}s")
        logger.info(f"  Batch translation average time: {batch_stats['mean_time']:.4f}s (for {len(sentences)} sentences)")
        logger.info(f"  Speedup: {speedup:.2f}x")
        logger.info(f"  Batch speedup: {batch_speedup:.2f}x")

    with open(output_file, "w") as out:
        json.dump(results, out, indent=2)

    logger.info(f"Benchmark results saved to {output_file}")

    return results
|
|
|
|
|
def main(): |
|
|
"""Main entry point for the benchmark script.""" |
|
|
parser = argparse.ArgumentParser( |
|
|
description="Benchmark translation models performance" |
|
|
) |
|
|
|
|
|
parser.add_argument( |
|
|
"--lang-pairs", |
|
|
type=str, |
|
|
nargs="+", |
|
|
default=["en-es", "en-fr", "en-de", "en-dra"], |
|
|
help="Language pairs to benchmark (e.g., 'en-es en-fr')" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--runs", |
|
|
type=int, |
|
|
default=5, |
|
|
help="Number of benchmark runs" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--warm-up", |
|
|
type=int, |
|
|
default=2, |
|
|
help="Number of warm-up runs" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--output", |
|
|
type=str, |
|
|
default="benchmark_results.json", |
|
|
help="Output file for benchmark results" |
|
|
) |
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
lang_pairs = [] |
|
|
for pair in args.lang_pairs: |
|
|
if "-" in pair: |
|
|
src, tgt = pair.split("-") |
|
|
lang_pairs.append((src, tgt)) |
|
|
else: |
|
|
logger.warning(f"Invalid language pair format: {pair}, skipping...") |
|
|
|
|
|
if not lang_pairs: |
|
|
logger.error("No valid language pairs specified") |
|
|
return 1 |
|
|
|
|
|
|
|
|
run_benchmarks( |
|
|
lang_pairs=lang_pairs, |
|
|
num_runs=args.runs, |
|
|
warm_up=args.warm_up, |
|
|
output_file=args.output |
|
|
) |
|
|
|
|
|
return 0 |
|
|
|
|
|
if __name__ == "__main__": |
|
|
sys.exit(main()) |
|
|
|