# MedAI_Processing / vi / translator.py
# Commit e138b0e — "Force translation with Llama" (LiamKhoaLe)
"""
Vietnamese Translator using Helsinki-NLP/opus-mt-en-vi model
"""
import os
import logging
from typing import List, Dict, Any, Optional, Union
from transformers import MarianMTModel, MarianTokenizer
import torch
logger = logging.getLogger(__name__)
class VietnameseTranslator:
    """
    English→Vietnamese translator: LLM models (NVIDIA/Gemini) first, Opus fallback.

    The primary path delegates to an injected ``paraphraser`` object (LLM-backed);
    when it is absent, raises, or echoes the input, a local MarianMT model
    (``Helsinki-NLP/opus-mt-en-vi`` by default) is lazily loaded and used
    instead. Simple counters in ``_stats`` track translation outcomes.
    """
def __init__(self, model_name: Optional[str] = None, device: Optional[str] = None, paraphraser=None):
"""
Initialize the Vietnamese translator.
Args:
model_name: Hugging Face model name for fallback. Defaults to EN_VI env var or Helsinki-NLP/opus-mt-en-vi
device: Device to run the fallback model on ('cpu', 'cuda', 'auto'). Defaults to 'auto'
paraphraser: Paraphraser instance with LLM models for primary translation
"""
self.model_name = model_name or os.getenv("EN_VI", "Helsinki-NLP/opus-mt-en-vi")
self.device = self._get_device(device)
self.model = None
self.tokenizer = None
self._is_loaded = False
self._stats = {"total_translations": 0, "successful_translations": 0, "failed_translations": 0}
self.paraphraser = paraphraser # LLM-based translator
logger.info(f"VietnameseTranslator initialized with LLM models + Opus fallback: {self.model_name}")
logger.info(f"Using device: {self.device}")
def _get_device(self, device: Optional[str]) -> str:
"""Determine the best device to use for the model."""
if device:
return device
if torch.cuda.is_available():
return "cuda"
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return "mps"
else:
return "cpu"
def load_model(self) -> None:
"""Load the translation model and tokenizer."""
if self._is_loaded:
logger.debug("Model already loaded, skipping...")
return
try:
logger.info(f"Loading translation model: {self.model_name}")
logger.info(f"Loading on device: {self.device}")
# Set up cache directory
cache_dir = os.getenv("HF_HOME", os.path.abspath("cache/huggingface"))
os.makedirs(cache_dir, exist_ok=True)
# Load tokenizer
self.tokenizer = MarianTokenizer.from_pretrained(
self.model_name,
cache_dir=cache_dir
)
# Load model
self.model = MarianMTModel.from_pretrained(
self.model_name,
cache_dir=cache_dir
)
# Move model to device
self.model = self.model.to(self.device)
self.model.eval()
self._is_loaded = True
logger.info("✅ Translation model loaded successfully")
except Exception as e:
logger.error(f"❌ Failed to load translation model: {e}")
raise RuntimeError(f"Failed to load translation model: {e}")
def translate_text(self, text: str) -> str:
"""
Translate a single text from English to Vietnamese using LLM models first, Opus as fallback.
Args:
text: English text to translate
Returns:
Translated Vietnamese text
"""
if not text or not text.strip():
return text
try:
self._stats["total_translations"] += 1
# Try LLM-based translation first (NVIDIA/Gemini)
if self.paraphraser:
try:
translated = self.paraphraser.translate(text, target_lang="vi")
if translated and translated.strip() and translated.strip() != text.strip():
logger.debug(f"LLM Translation result: '{text[:50]}...' -> '{translated[:50]}...'")
self._stats["successful_translations"] += 1
return translated.strip()
else:
logger.debug("LLM translation failed or returned identical text, trying Opus fallback")
except Exception as e:
logger.debug(f"LLM translation failed: {e}, trying Opus fallback")
# Fallback to Opus model
if not self._is_loaded:
self.load_model()
# Prepare input with target language token
input_text = f">>vie<< {text.strip()}"
# Tokenize
inputs = self.tokenizer(
input_text,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512
).to(self.device)
# Translate
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_length=512,
num_beams=4,
early_stopping=True,
do_sample=False
)
# Decode
translated = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
logger.debug(f"Opus Translation result: '{text[:50]}...' -> '{translated[:50]}...'")
logger.debug(f"Are original and translated the same? {text.strip() == translated.strip()}")
# Track success
self._stats["successful_translations"] += 1
return translated.strip()
except Exception as e:
logger.error(f"Translation failed for text: '{text[:100]}...' - Error: {e}")
self._stats["failed_translations"] += 1
# Return original text if translation fails
return text
def translate_batch(self, texts: List[str], batch_size: int = 8) -> List[str]:
"""
Translate a batch of texts from English to Vietnamese.
Args:
texts: List of English texts to translate
batch_size: Number of texts to process in each batch
Returns:
List of translated Vietnamese texts
"""
if not self._is_loaded:
self.load_model()
if not texts:
return []
results = []
try:
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
logger.debug(f"Processing batch {i//batch_size + 1}/{(len(texts) + batch_size - 1)//batch_size}")
# Prepare batch with target language tokens
batch_inputs = [f">>vie<< {text.strip()}" for text in batch]
# Tokenize batch
inputs = self.tokenizer(
batch_inputs,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512
).to(self.device)
# Translate batch
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_length=512,
num_beams=4,
early_stopping=True,
do_sample=False
)
# Decode batch
batch_translations = [
self.tokenizer.decode(output, skip_special_tokens=True).strip()
for output in outputs
]
results.extend(batch_translations)
except Exception as e:
logger.error(f"Batch translation failed: {e}")
# Return original texts if translation fails
results = texts
logger.info(f"Translated {len(texts)} texts successfully")
return results
def translate_dict(self, data: Dict[str, Any], text_fields: List[str]) -> Dict[str, Any]:
"""
Translate specific text fields in a dictionary from English to Vietnamese.
Args:
data: Dictionary containing the data
text_fields: List of field names to translate
Returns:
Dictionary with translated text fields
"""
if not self._is_loaded:
self.load_model()
result = data.copy()
for field in text_fields:
if field in data and isinstance(data[field], str) and data[field].strip():
try:
result[field] = self.translate_text(data[field])
logger.debug(f"Translated field '{field}': '{data[field][:50]}...' -> '{result[field][:50]}...'")
except Exception as e:
logger.error(f"Failed to translate field '{field}': {e}")
# Keep original text if translation fails
result[field] = data[field]
return result
def translate_list_of_dicts(self, data_list: List[Dict[str, Any]], text_fields: List[str]) -> List[Dict[str, Any]]:
"""
Translate specific text fields in a list of dictionaries.
Args:
data_list: List of dictionaries containing the data
text_fields: List of field names to translate in each dictionary
Returns:
List of dictionaries with translated text fields
"""
if not data_list:
return []
logger.info(f"Translating {len(data_list)} items with fields: {text_fields}")
results = []
for i, data in enumerate(data_list):
try:
translated_data = self.translate_dict(data, text_fields)
results.append(translated_data)
if (i + 1) % 100 == 0:
logger.info(f"Translated {i + 1}/{len(data_list)} items")
except Exception as e:
logger.error(f"Failed to translate item {i}: {e}")
results.append(data) # Keep original data if translation fails
logger.info(f"Completed translation of {len(data_list)} items")
return results
def is_loaded(self) -> bool:
"""Check if the model is loaded."""
return self._is_loaded
def get_model_info(self) -> Dict[str, str]:
"""Get information about the loaded model."""
return {
"model_name": self.model_name,
"device": self.device,
"is_loaded": self._is_loaded
}
def get_stats(self) -> Dict[str, Any]:
"""Get translation statistics."""
return self._stats.copy()
def reset_stats(self) -> None:
"""Reset translation statistics."""
self._stats = {"total_translations": 0, "successful_translations": 0, "failed_translations": 0}