# lt_space/app/models/translation_model.py
# Author: Arsive2 — commit d0d0352 ("Updated comments")
import logging
import os
import re
from typing import Optional
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
logger = logging.getLogger(__name__)
class TranslationModel:
    """
    Translation model manager optimized for CPU inference.

    Prefers small per-pair OPUS-MT models (fast, often better quality) and
    falls back to the distilled NLLB-200 model for language pairs without an
    OPUS-MT checkpoint. All models are loaded lazily on first use and cached
    both on disk (``model_cache_dir``) and in memory.
    """

    def __init__(self, model_cache_dir: str = ".cache/models"):
        """
        Initialize the translation model manager.

        Args:
            model_cache_dir: Directory to cache downloaded models
        """
        self.model_cache_dir = model_cache_dir
        self.device = self._get_device()
        self.opus_mt_models = {}  # "src-tgt" key -> (model, tokenizer)
        self.fallback_model = None       # NLLB model, loaded on demand
        self.fallback_tokenizer = None
        self.initialized = False
        self.initialization_error = None

        # Create cache directory
        os.makedirs(model_cache_dir, exist_ok=True)

        try:
            # Heavy model loading is deferred until the first translate() call.
            logger.info("TranslationModel initialized - models will be loaded on demand")
            self.initialized = True
        except Exception as e:
            self.initialization_error = str(e)
            logger.error(f"Failed to initialize translation model: {str(e)}")

    def _get_device(self):
        """Get the best available device for model inference."""
        if torch.cuda.is_available():
            logger.info("Using CUDA GPU for translation")
            return torch.device("cuda")
        logger.info("Using CPU for translation")
        return torch.device("cpu")

    def _get_opus_mt_model_name(self, source_lang_code: str, target_lang_code: str) -> Optional[str]:
        """
        Get the appropriate OPUS-MT model name for the language pair.

        The returned repo name is not guaranteed to exist on the Hub;
        existence is discovered when loading (see _load_opus_mt_model).
        """
        # Normalize ISO codes to the codes used in OPUS-MT model names.
        # Currently an identity mapping; kept as a hook for languages whose
        # ISO code differs from the OPUS-MT naming scheme.
        lang_code_mapping = {
            'zh': 'zh',
            'en': 'en',
            'ar': 'ar',
            'fr': 'fr',
            'de': 'de',
            'ru': 'ru',
            'pt': 'pt',
            'es': 'es',
            'it': 'it',
            'nl': 'nl',
            'pl': 'pl',
            'ja': 'ja',
            'ko': 'ko',
        }
        source = lang_code_mapping.get(source_lang_code, source_lang_code)
        target = lang_code_mapping.get(target_lang_code, target_lang_code)
        # Try direct model first
        return f"Helsinki-NLP/opus-mt-{source}-{target}"

    def _load_opus_mt_model(self, source_lang_code: str, target_lang_code: str):
        """
        Load (or return the cached) OPUS-MT model for a language pair.

        Returns:
            (model, tokenizer) tuple, or None if no OPUS-MT model is
            available for the pair (caller falls back to NLLB).
        """
        # Check the in-memory cache first.
        key = f"{source_lang_code}-{target_lang_code}"
        if key in self.opus_mt_models:
            return self.opus_mt_models[key]

        model_name = self._get_opus_mt_model_name(source_lang_code, target_lang_code)
        try:
            logger.info(f"Loading OPUS-MT model: {model_name}")
            # FIX(comment): float16 only on GPU; CPU inference requires float32.
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                cache_dir=self.model_cache_dir,
                low_cpu_mem_usage=True
            )
            tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=self.model_cache_dir)
            model.to(self.device)
            logger.info(f"OPUS-MT model loaded successfully: {model_name}")
            # Cache the model
            self.opus_mt_models[key] = (model, tokenizer)
            return model, tokenizer
        except Exception as e:
            # Many pairs simply have no OPUS-MT checkpoint; this is an
            # expected miss, not an error — the caller falls back to NLLB.
            logger.warning(f"Could not load OPUS-MT model {model_name}: {str(e)}")
            return None

    def _load_fallback_model(self):
        """Load the fallback NLLB-200 model for language pairs without OPUS-MT models."""
        if self.fallback_model is not None:
            return
        try:
            # Use the small distilled version for efficiency on CPU
            model_name = "facebook/nllb-200-distilled-600M"
            logger.info(f"Loading fallback model: {model_name}")
            self.fallback_model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                cache_dir=self.model_cache_dir,
                low_cpu_mem_usage=True
            )
            self.fallback_tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=self.model_cache_dir)
            self.fallback_model.to(self.device)
            logger.info(f"Fallback model loaded successfully: {model_name}")
        except Exception as e:
            logger.error(f"Error loading fallback model: {str(e)}")
            raise

    def translate(self, text: str, source_lang_code: str, target_lang_code: str) -> str:
        """
        Translate text from source language to target language.

        Args:
            text: Text to translate
            source_lang_code: Source language code (ISO 639-1, e.g. "en")
            target_lang_code: Target language code (ISO 639-1, e.g. "fr")

        Returns:
            Translated text (whitespace-normalized).

        Raises:
            ValueError: if the manager failed to initialize.
            Exception: any model-loading or generation error is re-raised
                after logging.
        """
        try:
            if not self.initialized:
                raise ValueError(f"Translation model not properly initialized: {self.initialization_error}")

            # Try to use OPUS-MT model first (faster and often better quality)
            opus_mt_result = self._load_opus_mt_model(source_lang_code, target_lang_code)
            if opus_mt_result:
                model, tokenizer = opus_mt_result
                # FIX: truncate overly long inputs instead of crashing — the
                # generate() call below is capped at 512 tokens anyway.
                inputs = tokenizer(text, return_tensors="pt", padding=True,
                                   truncation=True, max_length=512)
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                with torch.no_grad():
                    outputs = model.generate(**inputs, max_length=512, num_beams=4, early_stopping=True)
                translated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
                logger.info(f"Translation completed using OPUS-MT model")
            else:
                # Fall back to NLLB model
                logger.info(f"No OPUS-MT model available for {source_lang_code}-{target_lang_code}, using fallback model")
                self._load_fallback_model()
                tokenizer = self.fallback_tokenizer
                model = self.fallback_model

                # NLLB language codes are like "eng_Latn", "fra_Latn", etc.
                nllb_source = _get_nllb_code(source_lang_code)
                nllb_target = _get_nllb_code(target_lang_code)

                # FIX: the source language must be set on the tokenizer BEFORE
                # tokenizing, so NLLB tags the input with the correct source
                # code. The original computed nllb_source but never used it.
                tokenizer.src_lang = nllb_source
                inputs = tokenizer(text, return_tensors="pt", padding=True,
                                   truncation=True, max_length=512)
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                # Force decoder to start with target language token.
                # FIX: `lang_code_to_id` was removed from NLLB tokenizers in
                # newer transformers releases; fall back to the generic
                # token-to-id conversion when it is absent.
                if hasattr(tokenizer, "lang_code_to_id"):
                    forced_bos_token_id = tokenizer.lang_code_to_id[nllb_target]
                else:
                    forced_bos_token_id = tokenizer.convert_tokens_to_ids(nllb_target)

                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        forced_bos_token_id=forced_bos_token_id,
                        max_length=512,
                        num_beams=4,
                        early_stopping=True
                    )
                translated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
                logger.info(f"Translation completed using fallback NLLB model")

            # Collapse runs of whitespace in the decoded output.
            return re.sub(r'\s+', ' ', translated_text).strip()
        except Exception as e:
            logger.error(f"Translation error: {str(e)}")
            raise
def _get_nllb_code(lang_code: str) -> str:
"""Convert ISO language code to NLLB language code format."""
# Mapping for common languages
nllb_mapping = {
'en': 'eng_Latn',
'fr': 'fra_Latn',
'es': 'spa_Latn',
'de': 'deu_Latn',
'it': 'ita_Latn',
'pt': 'por_Latn',
'nl': 'nld_Latn',
'ru': 'rus_Cyrl',
'zh': 'zho_Hans',
'ar': 'ara_Arab',
'hi': 'hin_Deva',
'ja': 'jpn_Jpan',
'ko': 'kor_Hang',
}
return nllb_mapping.get(lang_code, f"{lang_code}_Latn")