| |
| """ |
| Utility script to convert Helsinki NLP Opus MT models to CTranslate2 format. |
| This script handles the special case of Dravidian languages. |
| """ |
|
|
| import argparse |
| import logging |
| import os |
| import sys |
|
|
| import torch |
|
|
# Configure root logging once at import time; all module messages go through
# this format (timestamp - logger name - level - message).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
|
|
| |
# (source, target) language-code pairs converted by --all.
# NOTE(review): "dra" (Dravidian languages) appears only as a target, not a
# source — presumably because no opus-mt-dra-en model exists; confirm.
COMMON_LANGUAGE_PAIRS = [
    ("en", "es"),
    ("en", "fr"),
    ("en", "de"),
    ("en", "ru"),
    ("en", "zh"),
    ("en", "ar"),
    ("en", "hi"),
    ("en", "dra"),
    ("es", "en"),
    ("fr", "en"),
    ("de", "en"),
    ("ru", "en"),
    ("zh", "en"),
    ("ar", "en"),
    ("hi", "en"),
]
|
|
# Accepted values for --quantization, mapped to the human-readable
# description printed by --list.
QUANTIZATION_TYPES = {
    "int8": "8-bit integer quantization (best for CPU)",
    "int16": "16-bit integer quantization",
    "float16": "16-bit floating point (best for modern GPUs)",
    # NOTE(review): "float8" does not look like a quantization type that
    # CTranslate2 accepts — verify against the installed ctranslate2 version,
    # otherwise selecting it will fail inside convert_model.
    "float8": "8-bit floating point (experimental)",
    "auto": "Automatic selection based on device",
}
|
|
def get_device() -> str:
    """Return the preferred inference device: "cuda" when available, else "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"
|
|
def get_auto_quantization(device: str) -> str:
    """Return the default quantization for *device*: float16 on CUDA, int8 otherwise."""
    defaults = {"cuda": "float16"}
    return defaults.get(device, "int8")
|
|
def get_huggingface_model_name(src_lang: str, tgt_lang: str) -> str:
    """Build the Helsinki-NLP Opus-MT repository id for a language pair."""
    pair = f"{src_lang}-{tgt_lang}"
    return "Helsinki-NLP/opus-mt-" + pair
|
|
def convert_model(
    src_lang: str,
    tgt_lang: str,
    output_dir: str,
    quantization: str = "auto",
    force: bool = False
) -> bool:
    """
    Convert a Helsinki NLP model to CTranslate2 format.

    Tries the in-process ``ctranslate2.converters.TransformersConverter``
    first and falls back to the ``ct2-transformers-converter`` command-line
    tool when the Python converter API is unavailable.

    Args:
        src_lang: Source language code (e.g. "en").
        tgt_lang: Target language code (e.g. "es").
        output_dir: Directory under which the converted model is written,
            as ``ct2_<src>-<tgt>``.
        quantization: Quantization type; "auto" resolves to float16 on CUDA
            and int8 on CPU.
        force: Re-convert even when the output directory already exists.

    Returns:
        bool: True on success (or when the model already exists and
        ``force`` is False), False on failure.
    """
    try:
        model_key = f"{src_lang}-{tgt_lang}"
        model_dir = os.path.join(output_dir, f"ct2_{model_key}")

        # Skip work when a converted model is already on disk.
        # (os.path.isdir implies existence, so no separate exists() check.)
        if os.path.isdir(model_dir) and not force:
            logger.info(
                "Model %s already exists at %s. Use --force to overwrite.",
                model_key, model_dir,
            )
            return True

        huggingface_model = get_huggingface_model_name(src_lang, tgt_lang)
        logger.info("Converting model %s to CTranslate2 format", huggingface_model)

        try:
            import ctranslate2  # noqa: F401 -- availability check only
        except ImportError:
            logger.error("CTranslate2 is not installed. Please install with 'pip install ctranslate2'")
            return False

        # Resolve "auto" quantization from the available hardware.
        device = get_device()
        if quantization == "auto":
            quantization = get_auto_quantization(device)

        logger.info("Using %s quantization for %s device", quantization, device)

        try:
            from ctranslate2.converters import TransformersConverter

            converter = TransformersConverter(huggingface_model)

            # force=True is safe here: the existing-model early return above
            # already honoured the caller's `force` flag.
            converter.convert(
                model_dir,
                quantization=quantization,
                force=True
            )

            logger.info(
                "Successfully converted %s to CTranslate2 format at %s",
                huggingface_model, model_dir,
            )
            return True

        except ImportError:
            # Some ctranslate2 builds do not expose the Python converter API;
            # shell out to the bundled CLI tool instead.
            logger.warning("Could not import TransformersConverter, falling back to command line")

            import subprocess
            # Argument list (shell=False) — no shell-injection surface.
            cmd = [
                "ct2-transformers-converter",
                "--model", huggingface_model,
                "--output_dir", model_dir,
                "--quantization", quantization,
                "--force"
            ]

            logger.info("Running command: %s", " ".join(cmd))
            result = subprocess.run(cmd, capture_output=True, text=True)

            if result.returncode == 0:
                logger.info("Successfully converted model using shell command")
                return True
            logger.error("Error in shell command: %s", result.stderr)
            return False

    except Exception as e:
        # Broad catch is deliberate: one failed pair must not abort a --all
        # batch; callers only consume the boolean status.
        logger.error("Error converting model %s-%s: %s", src_lang, tgt_lang, e)
        return False
|
|
def main():
    """Main entry point for the script.

    Returns:
        int: Process exit code — 0 when all requested work succeeded,
        1 when any conversion failed.
    """
    parser = argparse.ArgumentParser(
        description="Convert Helsinki NLP Opus MT models to CTranslate2 format"
    )

    parser.add_argument(
        "--src",
        type=str,
        help="Source language code (e.g., 'en')"
    )
    parser.add_argument(
        "--tgt",
        type=str,
        help="Target language code (e.g., 'es', 'fr', 'dra')"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=".cache/ct2_models",
        help="Output directory for converted models"
    )
    parser.add_argument(
        "--quantization",
        type=str,
        choices=list(QUANTIZATION_TYPES.keys()),
        default="auto",
        help="Quantization type to use"
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Force conversion even if model exists"
    )
    parser.add_argument(
        "--all",
        action="store_true",
        help="Convert all common language pairs"
    )
    parser.add_argument(
        "--list",
        action="store_true",
        help="List all common language pairs"
    )

    args = parser.parse_args()

    # --list is purely informational; handle it before touching the
    # filesystem so a listing run never creates the output directory
    # as a side effect.
    if args.list:
        print("\nCommon language pairs:")
        for src, tgt in COMMON_LANGUAGE_PAIRS:
            print(f"  {src}-{tgt}")
        print("\nQuantization types:")
        for q_type, desc in QUANTIZATION_TYPES.items():
            print(f"  {q_type}: {desc}")
        return 0

    os.makedirs(args.output_dir, exist_ok=True)

    # Batch mode: convert every known pair and report a summary.
    if args.all:
        results = {}
        for src_lang, tgt_lang in COMMON_LANGUAGE_PAIRS:
            model_key = f"{src_lang}-{tgt_lang}"
            logger.info("Processing model pair: %s", model_key)

            success = convert_model(
                src_lang=src_lang,
                tgt_lang=tgt_lang,
                output_dir=args.output_dir,
                quantization=args.quantization,
                force=args.force
            )

            results[model_key] = success

        logger.info("\n=== Conversion Summary ===")
        success_count = sum(1 for success in results.values() if success)
        logger.info("Successfully converted %d of %d models", success_count, len(results))

        for model_key, success in results.items():
            status = "✓" if success else "✗"
            logger.info("%s %s", status, model_key)

        return 0 if all(results.values()) else 1

    # Single-pair mode requires both language codes.
    if not args.src or not args.tgt:
        parser.error("--src and --tgt are required unless --all or --list is specified")

    success = convert_model(
        src_lang=args.src,
        tgt_lang=args.tgt,
        output_dir=args.output_dir,
        quantization=args.quantization,
        force=args.force
    )

    return 0 if success else 1
|
|
# Script entry point: propagate main()'s integer status as the process
# exit code (0 on success, 1 on failure).
if __name__ == "__main__":
    sys.exit(main())