# Author: Michael Hu
# feat: restrict supported languages to English and Chinese only
# Commit: 2418415
"""NLLB translation provider implementation."""
import logging
from typing import Dict, List, Optional
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from ..base.translation_provider_base import TranslationProviderBase
from ...domain.exceptions import TranslationFailedException
logger = logging.getLogger(__name__)
class NLLBTranslationProvider(TranslationProviderBase):
    """NLLB-200-3.3B translation provider implementation.

    Wraps the Hugging Face ``facebook/nllb-200-3.3B`` seq2seq model behind the
    project's ``TranslationProviderBase`` interface. The model and tokenizer
    are loaded lazily on first translation and can be swapped or evicted at
    runtime via :meth:`set_model_name` / :meth:`clear_model_cache`.
    """

    # Mapping from standard language codes to NLLB/FLORES-200 codes.
    # Per the current feature scope, only English and Chinese are supported.
    LANGUAGE_MAPPINGS = {
        'en': 'eng_Latn',
        'zh': 'zho_Hans',
        'zh-cn': 'zho_Hans',
        'zh-tw': 'zho_Hant'
    }

    def __init__(self, model_name: str = "facebook/nllb-200-3.3B", max_chunk_length: int = 1000):
        """
        Initialize NLLB translation provider.

        Args:
            model_name: The NLLB model name to use
            max_chunk_length: Maximum length for text chunks
        """
        # Each language may translate to every *other* supported language.
        supported_languages = {
            source: [target for target in self.LANGUAGE_MAPPINGS if target != source]
            for source in self.LANGUAGE_MAPPINGS
        }
        super().__init__(
            provider_name="NLLB-200-3.3B",
            supported_languages=supported_languages
        )
        self.model_name = model_name
        self.max_chunk_length = max_chunk_length
        # Tokenizer/model are populated lazily by _ensure_model_loaded().
        self._tokenizer: Optional[AutoTokenizer] = None
        self._model: Optional[AutoModelForSeq2SeqLM] = None
        self._model_loaded = False

    def _translate_chunk(self, text: str, source_language: str, target_language: str) -> str:
        """
        Translate a single chunk of text using the NLLB model.

        Args:
            text: The text chunk to translate
            source_language: Source language code (e.g. 'en', 'zh')
            target_language: Target language code (e.g. 'zh', 'en')

        Returns:
            str: The translated text chunk

        Raises:
            TranslationFailedException: If the model cannot be loaded;
                other failures are routed through _handle_provider_error.
        """
        try:
            self._ensure_model_loaded()

            # Map language codes to NLLB format.
            source_nllb = self._map_language_code(source_language)
            target_nllb = self._map_language_code(target_language)
            logger.info("Translating chunk from %s to %s", source_nllb, target_nllb)

            # BUG FIX: the tokenizer's source language must be set per call.
            # Previously it stayed at the load-time default ("eng_Latn"), so
            # e.g. zh -> en inputs were tagged as English, which mis-primes
            # NLLB and degrades translation quality.
            self._tokenizer.src_lang = source_nllb

            inputs = self._tokenizer(
                text,
                return_tensors="pt",
                max_length=1024,
                truncation=True
            )
            # forced_bos_token_id steers generation to the target language.
            outputs = self._model.generate(
                **inputs,
                forced_bos_token_id=self._tokenizer.convert_tokens_to_ids(target_nllb),
                max_new_tokens=1024,
                num_beams=4,
                early_stopping=True
            )
            translated = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
            translated = self._postprocess_text(translated)
            logger.info("Chunk translation completed: %d -> %d chars", len(text), len(translated))
            return translated
        except Exception as e:
            # Base-class hook; presumably wraps and re-raises — TODO confirm.
            self._handle_provider_error(e, "chunk translation")

    def _ensure_model_loaded(self) -> None:
        """Load the NLLB model and tokenizer once; no-op on later calls."""
        if self._model_loaded:
            return
        try:
            logger.info("Loading NLLB model: %s", self.model_name)
            # src_lang here is only a default; _translate_chunk overrides it
            # per translation with the actual source language.
            self._tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                src_lang="eng_Latn"
            )
            self._model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
            self._model_loaded = True
            logger.info("NLLB model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load NLLB model: {str(e)}")
            raise TranslationFailedException(f"Failed to load NLLB model: {str(e)}") from e

    def _map_language_code(self, language_code: str) -> str:
        """
        Map a standard language code to NLLB format.

        Args:
            language_code: Standard language code (e.g., 'en', 'zh')

        Returns:
            str: NLLB language code (e.g., 'eng_Latn', 'zho_Hans');
                 unknown codes fall back to 'eng_Latn' with a warning.
        """
        normalized_code = language_code.lower()
        # Direct hit in the explicit mapping table.
        if normalized_code in self.LANGUAGE_MAPPINGS:
            return self.LANGUAGE_MAPPINGS[normalized_code]
        # Handle Chinese variants not listed verbatim (e.g. 'zh-hant').
        if normalized_code.startswith('zh'):
            if 'tw' in normalized_code or 'hant' in normalized_code or 'traditional' in normalized_code:
                return 'zho_Hant'
            return 'zho_Hans'
        # Default fallback for unknown codes.
        logger.warning("Unknown language code: %s, defaulting to English", language_code)
        return 'eng_Latn'

    def is_available(self) -> bool:
        """
        Check if the NLLB translation provider is available.

        Returns:
            bool: True if provider is available, False otherwise
        """
        try:
            # Presence check only — the imports themselves are the test.
            import transformers  # noqa: F401
            import torch  # noqa: F401
            if self._model_loaded:
                return True
            try:
                # Lightweight check: loading the tokenizer is far cheaper
                # than loading the 3.3B-parameter model weights.
                AutoTokenizer.from_pretrained(
                    self.model_name,
                    src_lang="eng_Latn"
                )
                return True
            except Exception as e:
                logger.warning(f"NLLB model not available: {str(e)}")
                return False
        except ImportError as e:
            logger.warning(f"NLLB dependencies not available: {str(e)}")
            return False

    def get_supported_languages(self) -> Dict[str, List[str]]:
        """
        Get supported language pairs for NLLB provider.

        Returns:
            dict: Mapping of source languages to supported target languages
        """
        # Deep-copy the list values: the previous shallow .copy() aliased the
        # internal lists, so callers could mutate provider state in place.
        return {source: list(targets) for source, targets in self.supported_languages.items()}

    def get_model_info(self) -> Dict[str, str]:
        """
        Get information about the loaded model.

        Returns:
            dict: Model information (all values stringified)
        """
        return {
            'provider': self.provider_name,
            'model_name': self.model_name,
            'model_loaded': str(self._model_loaded),
            'supported_language_count': str(len(self.LANGUAGE_MAPPINGS)),
            'max_chunk_length': str(self.max_chunk_length)
        }

    def set_model_name(self, model_name: str) -> None:
        """
        Set a different NLLB model name.

        Invalidates any loaded model so the new one is loaded lazily on the
        next translation.

        Args:
            model_name: The new model name to use
        """
        if model_name != self.model_name:
            self.model_name = model_name
            self._model_loaded = False
            self._tokenizer = None
            self._model = None
            logger.info("Model name changed to: %s", model_name)

    def clear_model_cache(self) -> None:
        """Clear the loaded model from memory (reloaded lazily if needed)."""
        if self._model_loaded:
            self._tokenizer = None
            self._model = None
            self._model_loaded = False
            logger.info("NLLB model cache cleared")