Spaces:
Sleeping
Sleeping
| """NLLB translation provider implementation.""" | |
| import logging | |
| from typing import Dict, List, Optional | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| from ..base.translation_provider_base import TranslationProviderBase | |
| from ...domain.exceptions import TranslationFailedException | |
| logger = logging.getLogger(__name__) | |
class NLLBTranslationProvider(TranslationProviderBase):
    """NLLB-200-3.3B translation provider implementation.

    Lazily loads the Hugging Face NLLB model/tokenizer on first use and
    translates text chunks between the languages listed in LANGUAGE_MAPPINGS.
    """

    # Mapping from ISO-style language codes to NLLB's script-qualified codes.
    LANGUAGE_MAPPINGS = {
        'en': 'eng_Latn',
        'zh': 'zho_Hans',
        'zh-cn': 'zho_Hans',
        'zh-tw': 'zho_Hant'
    }

    def __init__(self, model_name: str = "facebook/nllb-200-3.3B", max_chunk_length: int = 1000):
        """
        Initialize NLLB translation provider.

        Args:
            model_name: The NLLB model name to use.
            max_chunk_length: Maximum length for text chunks.
        """
        # NLLB is many-to-many: assume every known language can translate to
        # every other known language.
        supported_languages = {
            lang_code: [target for target in self.LANGUAGE_MAPPINGS if target != lang_code]
            for lang_code in self.LANGUAGE_MAPPINGS
        }
        super().__init__(
            provider_name="NLLB-200-3.3B",
            supported_languages=supported_languages
        )
        self.model_name = model_name
        self.max_chunk_length = max_chunk_length
        # Model artifacts are loaded lazily on first translation (see
        # _ensure_model_loaded) to keep construction cheap.
        self._tokenizer: Optional[AutoTokenizer] = None
        self._model: Optional[AutoModelForSeq2SeqLM] = None
        self._model_loaded = False

    def _translate_chunk(self, text: str, source_language: str, target_language: str) -> str:
        """
        Translate a single chunk of text using the NLLB model.

        Args:
            text: The text chunk to translate.
            source_language: Source language code (e.g. 'en', 'zh').
            target_language: Target language code.

        Returns:
            str: The translated text chunk.

        Raises:
            Whatever exception self._handle_provider_error raises on failure.
        """
        try:
            self._ensure_model_loaded()

            # Map standard language codes to NLLB's script-qualified format.
            source_nllb = self._map_language_code(source_language)
            target_nllb = self._map_language_code(target_language)
            logger.info("Translating chunk from %s to %s", source_nllb, target_nllb)

            # BUG FIX: the tokenizer is created once with src_lang="eng_Latn";
            # it must be re-pointed at the actual source language before
            # tokenizing, otherwise non-English input is tagged as English and
            # translation quality degrades badly.
            self._tokenizer.src_lang = source_nllb

            inputs = self._tokenizer(
                text,
                return_tensors="pt",
                max_length=1024,
                truncation=True
            )

            # Pure inference: disable autograd so generate() does not build a
            # gradient graph (significant memory saving on a 3.3B model).
            import torch
            with torch.no_grad():
                # forced_bos_token_id steers decoding toward the target language.
                outputs = self._model.generate(
                    **inputs,
                    forced_bos_token_id=self._tokenizer.convert_tokens_to_ids(target_nllb),
                    max_new_tokens=1024,
                    num_beams=4,
                    early_stopping=True
                )

            translated = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
            translated = self._postprocess_text(translated)
            logger.info(
                "Chunk translation completed: %d -> %d chars", len(text), len(translated)
            )
            return translated
        except Exception as e:
            self._handle_provider_error(e, "chunk translation")
            # _handle_provider_error is expected to raise; re-raise defensively
            # so this method can never fall through and return None.
            raise

    def _ensure_model_loaded(self) -> None:
        """Ensure the NLLB model and tokenizer are loaded (idempotent).

        Raises:
            TranslationFailedException: If the model or tokenizer fails to load.
        """
        if self._model_loaded:
            return
        try:
            logger.info("Loading NLLB model: %s", self.model_name)
            # src_lang here is only a default; _translate_chunk overrides it
            # per call with the actual source language.
            self._tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                src_lang="eng_Latn"
            )
            self._model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
            self._model_loaded = True
            logger.info("NLLB model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load NLLB model: {str(e)}")
            raise TranslationFailedException(f"Failed to load NLLB model: {str(e)}") from e

    def _map_language_code(self, language_code: str) -> str:
        """
        Map a standard language code to NLLB format.

        Args:
            language_code: Standard language code (e.g., 'en', 'zh').

        Returns:
            str: NLLB language code (e.g., 'eng_Latn', 'zho_Hans'); falls back
                to 'eng_Latn' for unknown codes.
        """
        normalized_code = language_code.lower()
        if normalized_code in self.LANGUAGE_MAPPINGS:
            return self.LANGUAGE_MAPPINGS[normalized_code]
        # Handle common Chinese variants not covered by the direct mapping
        # (e.g. 'zh-hant', 'zh_traditional').
        if normalized_code.startswith('zh'):
            if 'tw' in normalized_code or 'hant' in normalized_code or 'traditional' in normalized_code:
                return 'zho_Hant'
            return 'zho_Hans'
        # Default fallback for unknown codes.
        logger.warning(f"Unknown language code: {language_code}, defaulting to English")
        return 'eng_Latn'

    def is_available(self) -> bool:
        """
        Check if the NLLB translation provider is available.

        Returns:
            bool: True if dependencies are importable and the tokenizer can be
                loaded (or the model is already loaded), False otherwise.
        """
        try:
            import transformers  # noqa: F401 -- availability probe only
            import torch  # noqa: F401 -- availability probe only
            if self._model_loaded:
                return True
            # NOTE(review): loading the tokenizer may download files on first
            # run, so this "lightweight" check can be slow the first time.
            try:
                AutoTokenizer.from_pretrained(
                    self.model_name,
                    src_lang="eng_Latn"
                )
                return True
            except Exception as e:
                logger.warning(f"NLLB model not available: {str(e)}")
                return False
        except ImportError as e:
            logger.warning(f"NLLB dependencies not available: {str(e)}")
            return False

    def get_supported_languages(self) -> Dict[str, List[str]]:
        """
        Get supported language pairs for the NLLB provider.

        Returns:
            dict: Copy of the source-language -> target-languages mapping.
        """
        return self.supported_languages.copy()

    def get_model_info(self) -> Dict[str, str]:
        """
        Get information about the configured model.

        Returns:
            dict: Model information (all values stringified for uniformity).
        """
        return {
            'provider': self.provider_name,
            'model_name': self.model_name,
            'model_loaded': str(self._model_loaded),
            'supported_language_count': str(len(self.LANGUAGE_MAPPINGS)),
            'max_chunk_length': str(self.max_chunk_length)
        }

    def set_model_name(self, model_name: str) -> None:
        """
        Switch to a different NLLB model.

        Args:
            model_name: The new model name to use. If it differs from the
                current one, cached model/tokenizer are dropped so the new
                model is loaded lazily on next use.
        """
        if model_name != self.model_name:
            self.model_name = model_name
            self._model_loaded = False
            self._tokenizer = None
            self._model = None
            logger.info(f"Model name changed to: {model_name}")

    def clear_model_cache(self) -> None:
        """Release the loaded model and tokenizer so memory can be reclaimed."""
        if self._model_loaded:
            self._tokenizer = None
            self._model = None
            self._model_loaded = False
            logger.info("NLLB model cache cleared")