import logging
import os
import re
from typing import Optional

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

logger = logging.getLogger(__name__)

# Shared cap for input tokenization (truncation) and generation length.
_MAX_LENGTH = 512


class TranslationModel:
    """
    CPU-friendly translation model manager.

    Prefers small per-language-pair Helsinki-NLP OPUS-MT models and falls
    back to the distilled NLLB-200 model for pairs without a dedicated
    OPUS-MT model. All models are downloaded/loaded lazily on first use
    and cached in memory for the lifetime of this object.
    """

    def __init__(self, model_cache_dir: str = ".cache/models"):
        """
        Initialize the translation model manager.

        Args:
            model_cache_dir: Directory to cache downloaded models
        """
        self.model_cache_dir = model_cache_dir
        self.device = self._get_device()
        self.opus_mt_models = {}  # "src-tgt" key -> (model, tokenizer)
        self.fallback_model = None  # NLLB model, loaded on demand
        self.fallback_tokenizer = None
        self.initialized = False
        self.initialization_error = None  # str error message, or None on success

        # Create cache directory
        os.makedirs(model_cache_dir, exist_ok=True)

        try:
            # Models are loaded on demand, so initialization itself is cheap.
            logger.info("TranslationModel initialized - models will be loaded on demand")
            self.initialized = True
        except Exception as e:
            self.initialization_error = str(e)
            logger.error(f"Failed to initialize translation model: {str(e)}")

    def _get_device(self) -> torch.device:
        """Get the best available device for model inference."""
        if torch.cuda.is_available():
            logger.info("Using CUDA GPU for translation")
            return torch.device("cuda")
        logger.info("Using CPU for translation")
        return torch.device("cpu")

    def _model_dtype(self) -> torch.dtype:
        """Dtype for model weights: fp16 on GPU to halve memory, fp32 on CPU."""
        return torch.float16 if self.device.type == "cuda" else torch.float32

    def _get_opus_mt_model_name(self, source_lang_code: str, target_lang_code: str) -> Optional[str]:
        """Get the appropriate OPUS-MT model name for the language pair.

        Helsinki-NLP OPUS-MT repositories use plain ISO 639-1 codes
        directly (e.g. "Helsinki-NLP/opus-mt-en-fr"), so the codes are
        interpolated as-is.
        """
        return f"Helsinki-NLP/opus-mt-{source_lang_code}-{target_lang_code}"

    def _load_opus_mt_model(self, source_lang_code: str, target_lang_code: str):
        """Load (and cache) an OPUS-MT model for the specific language pair.

        Returns:
            (model, tokenizer) on success, or None if no OPUS-MT model
            exists / could be loaded for this pair (caller falls back to NLLB).
        """
        model_name = self._get_opus_mt_model_name(source_lang_code, target_lang_code)

        # Check if model already loaded
        key = f"{source_lang_code}-{target_lang_code}"
        if key in self.opus_mt_models:
            return self.opus_mt_models[key]

        try:
            logger.info(f"Loading OPUS-MT model: {model_name}")
            # fp16 on GPU halves memory; fp32 is required for stable CPU inference.
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                torch_dtype=self._model_dtype(),
                cache_dir=self.model_cache_dir,
                low_cpu_mem_usage=True,
            )
            tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=self.model_cache_dir)
            model.to(self.device)
            logger.info(f"OPUS-MT model loaded successfully: {model_name}")

            # Cache the model
            self.opus_mt_models[key] = (model, tokenizer)
            return model, tokenizer
        except Exception as e:
            # Missing repo for this pair is expected; caller uses the fallback.
            logger.warning(f"Could not load OPUS-MT model {model_name}: {str(e)}")
            return None

    def _load_fallback_model(self):
        """Load the fallback NLLB-200 model for language pairs without OPUS-MT models.

        Idempotent: returns immediately if the model is already loaded.

        Raises:
            Exception: propagated if the model cannot be downloaded/loaded.
        """
        if self.fallback_model is not None:
            return

        try:
            # Use the small distilled version for efficiency on CPU
            model_name = "facebook/nllb-200-distilled-600M"
            logger.info(f"Loading fallback model: {model_name}")
            self.fallback_model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                torch_dtype=self._model_dtype(),
                cache_dir=self.model_cache_dir,
                low_cpu_mem_usage=True,
            )
            self.fallback_tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=self.model_cache_dir)
            self.fallback_model.to(self.device)
            logger.info(f"Fallback model loaded successfully: {model_name}")
        except Exception as e:
            logger.error(f"Error loading fallback model: {str(e)}")
            raise

    def _generate(self, model, tokenizer, text: str, forced_bos_token_id: Optional[int] = None) -> str:
        """Tokenize, run beam-search generation, and decode the best hypothesis.

        Args:
            model: A loaded seq2seq model already moved to ``self.device``.
            tokenizer: The matching tokenizer (src_lang must already be set
                by the caller for NLLB).
            text: Source text to translate.
            forced_bos_token_id: Target-language BOS token id (NLLB only).

        Returns:
            The decoded translation string.
        """
        # Truncate overly long inputs so they cannot exceed the model's
        # maximum sequence length.
        inputs = tokenizer(text, return_tensors="pt", padding=True,
                           truncation=True, max_length=_MAX_LENGTH)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        gen_kwargs = {"max_length": _MAX_LENGTH, "num_beams": 4, "early_stopping": True}
        if forced_bos_token_id is not None:
            gen_kwargs["forced_bos_token_id"] = forced_bos_token_id

        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
        return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    def translate(self, text: str, source_lang_code: str, target_lang_code: str) -> str:
        """
        Translate text from source language to target language.

        Args:
            text: Text to translate
            source_lang_code: Source language code
            target_lang_code: Target language code

        Returns:
            Translated text

        Raises:
            ValueError: if the manager failed to initialize.
            Exception: any model loading / generation error (logged first).
        """
        try:
            if not self.initialized:
                raise ValueError(f"Translation model not properly initialized: {self.initialization_error}")

            # Nothing to translate: avoid spinning up a model for whitespace.
            if not text or not text.strip():
                return ""

            # Try to use OPUS-MT model first (faster and often better quality)
            opus_mt_result = self._load_opus_mt_model(source_lang_code, target_lang_code)

            if opus_mt_result:
                model, tokenizer = opus_mt_result
                translated_text = self._generate(model, tokenizer, text)
                logger.info(f"Translation completed using OPUS-MT model")
            else:
                # Fall back to NLLB model
                logger.info(f"No OPUS-MT model available for {source_lang_code}-{target_lang_code}, using fallback model")
                self._load_fallback_model()
                tokenizer = self.fallback_tokenizer
                model = self.fallback_model

                # NLLB language codes are like "eng_Latn", "fra_Latn", etc.
                nllb_source = _get_nllb_code(source_lang_code)
                nllb_target = _get_nllb_code(target_lang_code)

                # BUG FIX: the source language must be set on the tokenizer
                # BEFORE tokenizing; otherwise NLLB treats every input as its
                # default source language and mistranslates non-default sources.
                tokenizer.src_lang = nllb_source

                # Force decoder to start with target language token.
                # convert_tokens_to_ids works across transformers versions
                # (lang_code_to_id was removed from newer NLLB tokenizers).
                forced_bos_token_id = tokenizer.convert_tokens_to_ids(nllb_target)

                translated_text = self._generate(
                    model, tokenizer, text, forced_bos_token_id=forced_bos_token_id
                )
                logger.info(f"Translation completed using fallback NLLB model")

            # Clean up the output: collapse runs of whitespace.
            return re.sub(r'\s+', ' ', translated_text).strip()
        except Exception as e:
            logger.error(f"Translation error: {str(e)}")
            raise


def _get_nllb_code(lang_code: str) -> str:
    """Convert ISO language code to NLLB language code format.

    Unknown codes are assumed to use Latin script ("xx_Latn"), which may be
    wrong for unlisted non-Latin languages — extend the mapping as needed.
    """
    # Mapping for common languages
    nllb_mapping = {
        'en': 'eng_Latn',
        'fr': 'fra_Latn',
        'es': 'spa_Latn',
        'de': 'deu_Latn',
        'it': 'ita_Latn',
        'pt': 'por_Latn',
        'nl': 'nld_Latn',
        'ru': 'rus_Cyrl',
        'zh': 'zho_Hans',
        'ar': 'ara_Arab',
        'hi': 'hin_Deva',
        'ja': 'jpn_Jpan',
        'ko': 'kor_Hang',
    }
    return nllb_mapping.get(lang_code, f"{lang_code}_Latn")