"""NLLB translation provider implementation."""

import logging
from typing import Dict, List, Optional

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from ..base.translation_provider_base import TranslationProviderBase
from ...domain.exceptions import TranslationFailedException

logger = logging.getLogger(__name__)


class NLLBTranslationProvider(TranslationProviderBase):
    """NLLB-200-3.3B translation provider implementation.

    The Hugging Face tokenizer and model are loaded lazily on the first chunk
    translation (see ``_ensure_model_loaded``) and kept on the instance until
    ``clear_model_cache`` or ``set_model_name`` drops them.
    """

    # Maps the provider's short language codes to NLLB language codes.
    LANGUAGE_MAPPINGS = {
        'en': 'eng_Latn',
        'zh': 'zho_Hans',
        'zh-cn': 'zho_Hans',
        'zh-tw': 'zho_Hant',
    }

    def __init__(self, model_name: str = "facebook/nllb-200-3.3B", max_chunk_length: int = 1000):
        """
        Initialize NLLB translation provider.

        No model weights are loaded here; loading is deferred to the first
        translation request.

        Args:
            model_name: The NLLB model name to use (passed to ``from_pretrained``)
            max_chunk_length: Maximum length for text chunks
        """
        # For simplicity, every mapped language is declared translatable to every
        # other mapped language. In practice you might want to be more specific
        # about supported pairs.
        codes = list(self.LANGUAGE_MAPPINGS)
        supported_languages = {
            source: [target for target in codes if target != source]
            for source in codes
        }

        super().__init__(
            provider_name="NLLB-200-3.3B",
            supported_languages=supported_languages,
        )

        self.model_name = model_name
        self.max_chunk_length = max_chunk_length
        self._tokenizer: Optional[AutoTokenizer] = None
        self._model: Optional[AutoModelForSeq2SeqLM] = None
        self._model_loaded = False

    def _translate_chunk(self, text: str, source_language: str, target_language: str) -> str:
        """
        Translate a single chunk of text using NLLB model.

        Args:
            text: The text chunk to translate (truncated to 1024 tokens)
            source_language: Source language code
            target_language: Target language code

        Returns:
            str: The translated text chunk
        """
        try:
            # Ensure model is loaded
            self._ensure_model_loaded()

            # Map language codes to NLLB format
            source_nllb = self._map_language_code(source_language)
            target_nllb = self._map_language_code(target_language)

            logger.info("Translating chunk from %s to %s", source_nllb, target_nllb)

            # NLLB tokenizers add a source-language special token derived from
            # ``src_lang``. The tokenizer is created with "eng_Latn", so it must be
            # switched per request; otherwise non-English input is tagged as English.
            self._tokenizer.src_lang = source_nllb

            inputs = self._tokenizer(
                text,
                return_tensors="pt",
                max_length=1024,
                truncation=True,
            )

            # Forcing the first generated token to the target-language token is
            # how NLLB selects the output language.
            outputs = self._model.generate(
                **inputs,
                forced_bos_token_id=self._tokenizer.convert_tokens_to_ids(target_nllb),
                max_new_tokens=1024,
                num_beams=4,
                early_stopping=True,
            )

            # Decode the single (batch of one) output sequence, dropping the
            # language/special tokens, then apply the base-class post-processing.
            translated = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
            translated = self._postprocess_text(translated)

            logger.info("Chunk translation completed: %d -> %d chars", len(text), len(translated))
            return translated

        except Exception as e:
            # NOTE(review): _handle_provider_error is defined outside this file;
            # presumably it raises, otherwise this method returns None — confirm.
            self._handle_provider_error(e, "chunk translation")

    def _ensure_model_loaded(self) -> None:
        """
        Ensure the NLLB model and tokenizer are loaded.

        Both objects are assigned to the instance only after both loads succeed,
        so a failed model load never leaves a half-initialized provider behind.

        Raises:
            TranslationFailedException: If the tokenizer or model cannot be loaded.
        """
        if self._model_loaded:
            return

        try:
            logger.info("Loading NLLB model: %s", self.model_name)

            # "eng_Latn" is only the initial source language; _translate_chunk
            # overrides src_lang for every request.
            tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                src_lang="eng_Latn",
            )
            model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        except Exception as e:
            logger.error("Failed to load NLLB model: %s", str(e))
            raise TranslationFailedException(f"Failed to load NLLB model: {str(e)}") from e

        self._tokenizer = tokenizer
        self._model = model
        self._model_loaded = True
        logger.info("NLLB model loaded successfully")

    def _map_language_code(self, language_code: str) -> str:
        """
        Map standard language code to NLLB format.

        Args:
            language_code: Standard language code (e.g., 'en', 'zh')

        Returns:
            str: NLLB language code (e.g., 'eng_Latn', 'zho_Hans'). Unknown codes
            fall back to 'eng_Latn' with a warning.
        """
        normalized_code = language_code.lower()

        # Exact match against the known mappings.
        if normalized_code in self.LANGUAGE_MAPPINGS:
            return self.LANGUAGE_MAPPINGS[normalized_code]

        # Other Chinese variants: Traditional if the code hints at it, else Simplified.
        if normalized_code.startswith('zh'):
            if 'tw' in normalized_code or 'hant' in normalized_code or 'traditional' in normalized_code:
                return 'zho_Hant'
            return 'zho_Hans'

        logger.warning("Unknown language code: %s, defaulting to English", language_code)
        return 'eng_Latn'

    def is_available(self) -> bool:
        """
        Check if the NLLB translation provider is available.

        When the model is not loaded yet, this loads the tokenizer via
        ``from_pretrained`` (which may download files) as a lightweight probe.

        Returns:
            bool: True if provider is available, False otherwise
        """
        try:
            # Imported only to verify the runtime dependencies are installed.
            import transformers  # noqa: F401
            import torch  # noqa: F401
        except ImportError as e:
            logger.warning("NLLB dependencies not available: %s", str(e))
            return False

        if self._model_loaded:
            return True

        try:
            AutoTokenizer.from_pretrained(self.model_name, src_lang="eng_Latn")
        except Exception as e:
            logger.warning("NLLB model not available: %s", str(e))
            return False
        return True

    def get_supported_languages(self) -> Dict[str, List[str]]:
        """
        Get supported language pairs for NLLB provider.

        Returns:
            dict: Mapping of source languages to supported target languages. The
            target lists are copied too, so callers cannot mutate provider state.
        """
        return {source: list(targets) for source, targets in self.supported_languages.items()}

    def get_model_info(self) -> Dict[str, str]:
        """
        Get information about the configured model.

        Returns:
            dict: Model information; every value is stringified.
        """
        return {
            'provider': self.provider_name,
            'model_name': self.model_name,
            'model_loaded': str(self._model_loaded),
            'supported_language_count': str(len(self.LANGUAGE_MAPPINGS)),
            'max_chunk_length': str(self.max_chunk_length),
        }

    def set_model_name(self, model_name: str) -> None:
        """
        Set a different NLLB model name.

        If the name changes, any loaded tokenizer/model is dropped so the new
        model is loaded lazily on the next translation.

        Args:
            model_name: The new model name to use
        """
        if model_name != self.model_name:
            self.model_name = model_name
            self._model_loaded = False
            self._tokenizer = None
            self._model = None
            logger.info("Model name changed to: %s", model_name)

    def clear_model_cache(self) -> None:
        """Clear the loaded model from memory by dropping this instance's references."""
        if self._model_loaded:
            self._tokenizer = None
            self._model = None
            self._model_loaded = False
            logger.info("NLLB model cache cleared")