""" Vietnamese Translator using Helsinki-NLP/opus-mt-en-vi model """ import os import logging from typing import List, Dict, Any, Optional, Union from transformers import MarianMTModel, MarianTokenizer import torch logger = logging.getLogger(__name__) class VietnameseTranslator: """ Vietnamese translator using LLM models (NVIDIA/Gemini) with Opus as fallback. This class handles translation from English to Vietnamese using LLM models for better quality, with Opus model as fallback. """ def __init__(self, model_name: Optional[str] = None, device: Optional[str] = None, paraphraser=None): """ Initialize the Vietnamese translator. Args: model_name: Hugging Face model name for fallback. Defaults to EN_VI env var or Helsinki-NLP/opus-mt-en-vi device: Device to run the fallback model on ('cpu', 'cuda', 'auto'). Defaults to 'auto' paraphraser: Paraphraser instance with LLM models for primary translation """ self.model_name = model_name or os.getenv("EN_VI", "Helsinki-NLP/opus-mt-en-vi") self.device = self._get_device(device) self.model = None self.tokenizer = None self._is_loaded = False self._stats = {"total_translations": 0, "successful_translations": 0, "failed_translations": 0} self.paraphraser = paraphraser # LLM-based translator logger.info(f"VietnameseTranslator initialized with LLM models + Opus fallback: {self.model_name}") logger.info(f"Using device: {self.device}") def _get_device(self, device: Optional[str]) -> str: """Determine the best device to use for the model.""" if device: return device if torch.cuda.is_available(): return "cuda" elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): return "mps" else: return "cpu" def load_model(self) -> None: """Load the translation model and tokenizer.""" if self._is_loaded: logger.debug("Model already loaded, skipping...") return try: logger.info(f"Loading translation model: {self.model_name}") logger.info(f"Loading on device: {self.device}") # Set up cache directory cache_dir = os.getenv("HF_HOME", os.path.abspath("cache/huggingface")) os.makedirs(cache_dir, exist_ok=True) # Load tokenizer self.tokenizer = MarianTokenizer.from_pretrained( self.model_name, cache_dir=cache_dir ) # Load model self.model = MarianMTModel.from_pretrained( self.model_name, cache_dir=cache_dir ) # Move model to device self.model = self.model.to(self.device) self.model.eval() self._is_loaded = True logger.info("✅ Translation model loaded successfully") except Exception as e: logger.error(f"❌ Failed to load translation model: {e}") raise RuntimeError(f"Failed to load translation model: {e}") def translate_text(self, text: str) -> str: """ Translate a single text from English to Vietnamese using LLM models first, Opus as fallback. Args: text: English text to translate Returns: Translated Vietnamese text """ if not text or not text.strip(): return text try: self._stats["total_translations"] += 1 # Try LLM-based translation first (NVIDIA/Gemini) if self.paraphraser: try: translated = self.paraphraser.translate(text, target_lang="vi") if translated and translated.strip() and translated.strip() != text.strip(): logger.debug(f"LLM Translation result: '{text[:50]}...' -> '{translated[:50]}...'") self._stats["successful_translations"] += 1 return translated.strip() else: logger.debug("LLM translation failed or returned identical text, trying Opus fallback") except Exception as e: logger.debug(f"LLM translation failed: {e}, trying Opus fallback") # Fallback to Opus model if not self._is_loaded: self.load_model() # Prepare input with target language token input_text = f">>vie<< {text.strip()}" # Tokenize inputs = self.tokenizer( input_text, return_tensors="pt", padding=True, truncation=True, max_length=512 ).to(self.device) # Translate with torch.no_grad(): outputs = self.model.generate( **inputs, max_length=512, num_beams=4, early_stopping=True, do_sample=False ) # Decode translated = self.tokenizer.decode(outputs[0], skip_special_tokens=True) logger.debug(f"Opus Translation result: '{text[:50]}...' -> '{translated[:50]}...'") logger.debug(f"Are original and translated the same? {text.strip() == translated.strip()}") # Track success self._stats["successful_translations"] += 1 return translated.strip() except Exception as e: logger.error(f"Translation failed for text: '{text[:100]}...' - Error: {e}") self._stats["failed_translations"] += 1 # Return original text if translation fails return text def translate_batch(self, texts: List[str], batch_size: int = 8) -> List[str]: """ Translate a batch of texts from English to Vietnamese. Args: texts: List of English texts to translate batch_size: Number of texts to process in each batch Returns: List of translated Vietnamese texts """ if not self._is_loaded: self.load_model() if not texts: return [] results = [] try: for i in range(0, len(texts), batch_size): batch = texts[i:i + batch_size] logger.debug(f"Processing batch {i//batch_size + 1}/{(len(texts) + batch_size - 1)//batch_size}") # Prepare batch with target language tokens batch_inputs = [f">>vie<< {text.strip()}" for text in batch] # Tokenize batch inputs = self.tokenizer( batch_inputs, return_tensors="pt", padding=True, truncation=True, max_length=512 ).to(self.device) # Translate batch with torch.no_grad(): outputs = self.model.generate( **inputs, max_length=512, num_beams=4, early_stopping=True, do_sample=False ) # Decode batch batch_translations = [ self.tokenizer.decode(output, skip_special_tokens=True).strip() for output in outputs ] results.extend(batch_translations) except Exception as e: logger.error(f"Batch translation failed: {e}") # Return original texts if translation fails results = texts logger.info(f"Translated {len(texts)} texts successfully") return results def translate_dict(self, data: Dict[str, Any], text_fields: List[str]) -> Dict[str, Any]: """ Translate specific text fields in a dictionary from English to Vietnamese. Args: data: Dictionary containing the data text_fields: List of field names to translate Returns: Dictionary with translated text fields """ if not self._is_loaded: self.load_model() result = data.copy() for field in text_fields: if field in data and isinstance(data[field], str) and data[field].strip(): try: result[field] = self.translate_text(data[field]) logger.debug(f"Translated field '{field}': '{data[field][:50]}...' -> '{result[field][:50]}...'") except Exception as e: logger.error(f"Failed to translate field '{field}': {e}") # Keep original text if translation fails result[field] = data[field] return result def translate_list_of_dicts(self, data_list: List[Dict[str, Any]], text_fields: List[str]) -> List[Dict[str, Any]]: """ Translate specific text fields in a list of dictionaries. Args: data_list: List of dictionaries containing the data text_fields: List of field names to translate in each dictionary Returns: List of dictionaries with translated text fields """ if not data_list: return [] logger.info(f"Translating {len(data_list)} items with fields: {text_fields}") results = [] for i, data in enumerate(data_list): try: translated_data = self.translate_dict(data, text_fields) results.append(translated_data) if (i + 1) % 100 == 0: logger.info(f"Translated {i + 1}/{len(data_list)} items") except Exception as e: logger.error(f"Failed to translate item {i}: {e}") results.append(data) # Keep original data if translation fails logger.info(f"Completed translation of {len(data_list)} items") return results def is_loaded(self) -> bool: """Check if the model is loaded.""" return self._is_loaded def get_model_info(self) -> Dict[str, str]: """Get information about the loaded model.""" return { "model_name": self.model_name, "device": self.device, "is_loaded": self._is_loaded } def get_stats(self) -> Dict[str, Any]: """Get translation statistics.""" return self._stats.copy() def reset_stats(self) -> None: """Reset translation statistics.""" self._stats = {"total_translations": 0, "successful_translations": 0, "failed_translations": 0}