# Source: lt_space/app/models/ct2_model_converter.py
# (Hugging Face Space upload header: Arsive2, "Updated ct translate", commit 6a6828e)
#!/usr/bin/env python
"""
Utility script to convert Helsinki NLP Opus MT models to CTranslate2 format.
This script handles the special case of Dravidian languages.
"""
import argparse
import logging
import os
import sys
import torch
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)
# Common language pairs
# (src, tgt) language codes accepted by --src/--tgt and iterated by --all.
# "dra" is the Opus-MT multilingual Dravidian-languages model id, not an ISO pair.
COMMON_LANGUAGE_PAIRS = [
    ("en", "es"),  # English to Spanish
    ("en", "fr"),  # English to French
    ("en", "de"),  # English to German
    ("en", "ru"),  # English to Russian
    ("en", "zh"),  # English to Chinese
    ("en", "ar"),  # English to Arabic
    ("en", "hi"),  # English to Hindi
    ("en", "dra"),  # English to Dravidian languages
    ("es", "en"),  # Spanish to English
    ("fr", "en"),  # French to English
    ("de", "en"),  # German to English
    ("ru", "en"),  # Russian to English
    ("zh", "en"),  # Chinese to English
    ("ar", "en"),  # Arabic to English
    ("hi", "en"),  # Hindi to English
]
# Valid values for the --quantization flag, mapped to human-readable descriptions.
# NOTE(review): "float8" does not appear to be a quantization type supported by
# CTranslate2 -- confirm against the installed ctranslate2 version before relying on it.
QUANTIZATION_TYPES = {
    "int8": "8-bit integer quantization (best for CPU)",
    "int16": "16-bit integer quantization",
    "float16": "16-bit floating point (best for modern GPUs)",
    "float8": "8-bit floating point (experimental)",
    "auto": "Automatic selection based on device",
}
def get_device() -> str:
    """Return the preferred inference device: "cuda" when available, else "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"
def get_auto_quantization(device: str) -> str:
    """Return the default quantization for *device*: "float16" on CUDA, "int8" otherwise."""
    defaults = {"cuda": "float16"}
    return defaults.get(device, "int8")
def get_huggingface_model_name(src_lang: str, tgt_lang: str) -> str:
    """Build the Helsinki-NLP Opus-MT Hub repo id for the given language pair."""
    return "Helsinki-NLP/opus-mt-{}-{}".format(src_lang, tgt_lang)
def convert_model(
    src_lang: str,
    tgt_lang: str,
    output_dir: str,
    quantization: str = "auto",
    force: bool = False
) -> bool:
    """
    Convert a Helsinki NLP Opus-MT model to CTranslate2 format.

    Tries the in-process ``TransformersConverter`` first; if that import
    fails, falls back to the ``ct2-transformers-converter`` CLI.

    Args:
        src_lang: Source language code (e.g. "en")
        tgt_lang: Target language code (e.g. "es", "dra")
        output_dir: Directory under which the converted model is written
            as ``ct2_<src>-<tgt>``
        quantization: Quantization type; "auto" selects based on device
        force: Whether to re-convert even if the model directory exists

    Returns:
        bool: True on success (including the already-converted short-circuit),
        False on any failure.
    """
    try:
        model_key = f"{src_lang}-{tgt_lang}"
        model_dir = os.path.join(output_dir, f"ct2_{model_key}")

        # Short-circuit if a previous conversion is already on disk.
        # (os.path.isdir is False for nonexistent paths, so the extra
        # exists() check the original had was redundant.)
        if os.path.isdir(model_dir) and not force:
            logger.info(
                "Model %s already exists at %s. Use --force to overwrite.",
                model_key, model_dir,
            )
            return True

        huggingface_model = get_huggingface_model_name(src_lang, tgt_lang)
        logger.info("Converting model %s to CTranslate2 format", huggingface_model)

        # CTranslate2 is an optional dependency; fail gracefully when absent.
        try:
            import ctranslate2  # noqa: F401 -- presence check only
        except ImportError:
            logger.error("CTranslate2 is not installed. Please install with 'pip install ctranslate2'")
            return False

        device = get_device()
        if quantization == "auto":
            quantization = get_auto_quantization(device)
            logger.info("Using %s quantization for %s device", quantization, device)

        try:
            from ctranslate2.converters import TransformersConverter
            converter = TransformersConverter(huggingface_model)
            # force=True is safe here: the not-force case already returned above,
            # so reaching this point always means we intend to (re)write model_dir.
            converter.convert(
                model_dir,
                quantization=quantization,
                force=True
            )
            logger.info(
                "Successfully converted %s to CTranslate2 format at %s",
                huggingface_model, model_dir,
            )
            return True
        except ImportError:
            # Some installs lack the Python converter API but still ship the
            # console entry point; shell out to it instead.
            logger.warning("Could not import TransformersConverter, falling back to command line")
            import subprocess
            cmd = [
                "ct2-transformers-converter",
                "--model", huggingface_model,
                "--output_dir", model_dir,
                "--quantization", quantization,
                "--force"
            ]
            logger.info("Running command: %s", " ".join(cmd))
            # List argv (shell=False) avoids any shell-quoting/injection issues.
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode == 0:
                logger.info("Successfully converted model using shell command")
                return True
            logger.error("Error in shell command: %s", result.stderr)
            return False
    except Exception as e:
        # Boundary catch-all: a failed conversion must report False, not crash
        # the --all loop in main().
        logger.error("Error converting model %s-%s: %s", src_lang, tgt_lang, e)
        return False
def main():
    """CLI entry point: parse arguments and dispatch to the requested action.

    Supports listing known pairs (--list), batch-converting every common
    pair (--all), or converting a single --src/--tgt pair. Returns a shell
    exit status (0 on success, 1 on any failure).
    """
    parser = argparse.ArgumentParser(
        description="Convert Helsinki NLP Opus MT models to CTranslate2 format"
    )
    parser.add_argument("--src", type=str, help="Source language code (e.g., 'en')")
    parser.add_argument("--tgt", type=str, help="Target language code (e.g., 'es', 'fr', 'dra')")
    parser.add_argument(
        "--output-dir",
        type=str,
        default=".cache/ct2_models",
        help="Output directory for converted models",
    )
    parser.add_argument(
        "--quantization",
        type=str,
        choices=list(QUANTIZATION_TYPES.keys()),
        default="auto",
        help="Quantization type to use",
    )
    parser.add_argument("--force", action="store_true", help="Force conversion even if model exists")
    parser.add_argument("--all", action="store_true", help="Convert all common language pairs")
    parser.add_argument("--list", action="store_true", help="List all common language pairs")
    args = parser.parse_args()

    # Make sure the destination exists before any conversion starts.
    os.makedirs(args.output_dir, exist_ok=True)

    if args.list:
        # Informational mode: print the supported pairs/quantizations and exit.
        print("\nCommon language pairs:")
        for pair_src, pair_tgt in COMMON_LANGUAGE_PAIRS:
            print(f"  {pair_src}-{pair_tgt}")
        print("\nQuantization types:")
        for name, description in QUANTIZATION_TYPES.items():
            print(f"  {name}: {description}")
        return 0

    if args.all:
        # Batch mode: convert every common pair, then summarize per-pair results.
        outcomes = {}
        for pair_src, pair_tgt in COMMON_LANGUAGE_PAIRS:
            key = f"{pair_src}-{pair_tgt}"
            logger.info(f"Processing model pair: {key}")
            outcomes[key] = convert_model(
                src_lang=pair_src,
                tgt_lang=pair_tgt,
                output_dir=args.output_dir,
                quantization=args.quantization,
                force=args.force,
            )
        logger.info("\n=== Conversion Summary ===")
        converted = sum(1 for ok in outcomes.values() if ok)
        logger.info(f"Successfully converted {converted} of {len(outcomes)} models")
        for key, ok in outcomes.items():
            logger.info(f"{'✓' if ok else '✗'} {key}")
        return 0 if all(outcomes.values()) else 1

    # Single-pair mode requires both endpoints; parser.error() exits with usage.
    if not args.src or not args.tgt:
        parser.error("--src and --tgt are required unless --all or --list is specified")

    converted_ok = convert_model(
        src_lang=args.src,
        tgt_lang=args.tgt,
        output_dir=args.output_dir,
        quantization=args.quantization,
        force=args.force,
    )
    return 0 if converted_ok else 1
if __name__ == "__main__":
    # Propagate main()'s 0/1 status code to the shell.
    sys.exit(main())