|
|
import logging |
|
|
import os |
|
|
import re |
|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class TranslationModel:
    """
    CPU-friendly translation manager.

    Prefers small per-language-pair OPUS-MT models and lazily falls back
    to the distilled NLLB-200 model for pairs without an OPUS-MT model.
    Models are downloaded/loaded on first use and cached in memory.
    """

    # Marian (OPUS-MT) models have a 512-position limit; generation is
    # capped at the same length.
    MAX_LENGTH = 512

    def __init__(self, model_cache_dir: str = ".cache/models"):
        """
        Initialize the translation model manager.

        Args:
            model_cache_dir: Directory to cache downloaded models.
        """
        self.model_cache_dir = model_cache_dir
        self.device = self._get_device()
        # "src-tgt" pair key -> (model, tokenizer), filled on demand.
        self.opus_mt_models = {}
        # NLLB-200 fallback, loaded lazily by _load_fallback_model().
        self.fallback_model = None
        self.fallback_tokenizer = None
        self.initialized = False
        self.initialization_error = None

        os.makedirs(model_cache_dir, exist_ok=True)

        try:
            logger.info("TranslationModel initialized - models will be loaded on demand")
            self.initialized = True
        except Exception as e:
            self.initialization_error = str(e)
            logger.error("Failed to initialize translation model: %s", e)

    def _get_device(self):
        """Get the best available device for model inference."""
        if torch.cuda.is_available():
            logger.info("Using CUDA GPU for translation")
            return torch.device("cuda")
        logger.info("Using CPU for translation")
        return torch.device("cpu")

    def _get_opus_mt_model_name(self, source_lang_code: str, target_lang_code: str) -> Optional[str]:
        """Get the appropriate OPUS-MT model name for the language pair.

        Helsinki-NLP publishes models as ``opus-mt-<src>-<tgt>`` using the
        ISO 639-1 codes directly. (The previous explicit code table was an
        identity mapping with a pass-through default, i.e. equivalent.)
        """
        return f"Helsinki-NLP/opus-mt-{source_lang_code}-{target_lang_code}"

    def _load_opus_mt_model(self, source_lang_code: str, target_lang_code: str):
        """Load (and memoize) the OPUS-MT model for a language pair.

        Returns:
            ``(model, tokenizer)`` on success, or ``None`` when no model is
            available for this pair (callers then use the NLLB fallback).
        """
        key = f"{source_lang_code}-{target_lang_code}"
        if key in self.opus_mt_models:
            return self.opus_mt_models[key]

        model_name = self._get_opus_mt_model_name(source_lang_code, target_lang_code)
        try:
            logger.info("Loading OPUS-MT model: %s", model_name)

            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                # fp16 only on GPU; CPU inference needs fp32.
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                cache_dir=self.model_cache_dir,
                low_cpu_mem_usage=True,
            )
            tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=self.model_cache_dir)

            model.to(self.device)
            model.eval()  # inference only: disable dropout etc.
            logger.info("OPUS-MT model loaded successfully: %s", model_name)

            self.opus_mt_models[key] = (model, tokenizer)
            return model, tokenizer

        except Exception as e:
            # Many pairs simply have no published OPUS-MT model; treat any
            # load failure as "not available" so the caller falls back.
            logger.warning("Could not load OPUS-MT model %s: %s", model_name, e)
            return None

    def _load_fallback_model(self):
        """Load the fallback NLLB-200 model for language pairs without OPUS-MT models."""
        if self.fallback_model is not None:
            return  # already loaded

        try:
            model_name = "facebook/nllb-200-distilled-600M"
            logger.info("Loading fallback model: %s", model_name)

            self.fallback_model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                cache_dir=self.model_cache_dir,
                low_cpu_mem_usage=True,
            )
            self.fallback_tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=self.model_cache_dir)

            self.fallback_model.to(self.device)
            self.fallback_model.eval()  # inference only
            logger.info("Fallback model loaded successfully: %s", model_name)

        except Exception as e:
            # Without the fallback the manager cannot serve unsupported
            # pairs at all, so this failure is fatal to the caller.
            logger.error("Error loading fallback model: %s", e)
            raise

    def translate(self, text: str, source_lang_code: str, target_lang_code: str) -> str:
        """
        Translate text from source language to target language.

        Args:
            text: Text to translate.
            source_lang_code: Source language code (ISO 639-1).
            target_lang_code: Target language code (ISO 639-1).

        Returns:
            Translated text, whitespace-normalized. Empty/blank input
            returns an empty string without touching any model.

        Raises:
            ValueError: If the manager failed to initialize.
            Exception: Re-raises any model loading/generation error.
        """
        try:
            if not self.initialized:
                raise ValueError(f"Translation model not properly initialized: {self.initialization_error}")

            if not text or not text.strip():
                return ""  # nothing to translate

            pair = self._load_opus_mt_model(source_lang_code, target_lang_code)
            if pair is not None:
                model, tokenizer = pair
                translated_text = self._generate(model, tokenizer, text)
                logger.info("Translation completed using OPUS-MT model")
            else:
                logger.info(
                    "No OPUS-MT model available for %s-%s, using fallback model",
                    source_lang_code, target_lang_code,
                )
                self._load_fallback_model()
                tokenizer = self.fallback_tokenizer
                model = self.fallback_model

                nllb_source = _get_nllb_code(source_lang_code)
                nllb_target = _get_nllb_code(target_lang_code)

                # Bug fix: src_lang must be set *before* tokenizing so the
                # correct source-language token is prepended; otherwise the
                # NLLB tokenizer assumes eng_Latn input.
                tokenizer.src_lang = nllb_source
                forced_bos_token_id = self._nllb_bos_id(tokenizer, nllb_target)
                translated_text = self._generate(model, tokenizer, text, forced_bos_token_id)
                logger.info("Translation completed using fallback NLLB model")

            # Collapse whitespace runs the decoder may emit.
            return re.sub(r'\s+', ' ', translated_text).strip()

        except Exception as e:
            logger.error("Translation error: %s", e)
            raise

    def _generate(self, model, tokenizer, text: str, forced_bos_token_id=None) -> str:
        """Tokenize *text*, run beam-search generation, and decode one string.

        Args:
            model: A loaded seq2seq model already on ``self.device``.
            tokenizer: Its matching tokenizer.
            text: Source text.
            forced_bos_token_id: Target-language BOS id for multilingual
                models (NLLB); ``None`` for single-pair models (OPUS-MT).
        """
        # Truncate at the positional limit to avoid index errors on very
        # long inputs (bug fix: inputs were previously unbounded).
        inputs = tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.MAX_LENGTH,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        gen_kwargs = {"max_length": self.MAX_LENGTH, "num_beams": 4, "early_stopping": True}
        if forced_bos_token_id is not None:
            gen_kwargs["forced_bos_token_id"] = forced_bos_token_id

        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
        return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    @staticmethod
    def _nllb_bos_id(tokenizer, nllb_lang_code: str) -> int:
        """Resolve the forced-BOS token id for an NLLB language code.

        transformers >= 4.34 removed ``lang_code_to_id`` from the NLLB
        tokenizer; there the language codes are ordinary vocabulary tokens,
        so fall back to ``convert_tokens_to_ids``.
        """
        lang_map = getattr(tokenizer, "lang_code_to_id", None)
        if lang_map is not None:
            return lang_map[nllb_lang_code]
        return tokenizer.convert_tokens_to_ids(nllb_lang_code)
|
|
|
|
|
def _get_nllb_code(lang_code: str) -> str: |
|
|
"""Convert ISO language code to NLLB language code format.""" |
|
|
|
|
|
nllb_mapping = { |
|
|
'en': 'eng_Latn', |
|
|
'fr': 'fra_Latn', |
|
|
'es': 'spa_Latn', |
|
|
'de': 'deu_Latn', |
|
|
'it': 'ita_Latn', |
|
|
'pt': 'por_Latn', |
|
|
'nl': 'nld_Latn', |
|
|
'ru': 'rus_Cyrl', |
|
|
'zh': 'zho_Hans', |
|
|
'ar': 'ara_Arab', |
|
|
'hi': 'hin_Deva', |
|
|
'ja': 'jpn_Jpan', |
|
|
'ko': 'kor_Hang', |
|
|
} |
|
|
|
|
|
return nllb_mapping.get(lang_code, f"{lang_code}_Latn") |