| """ |
| Neural Machine Translation Module for Multilingual Audio Intelligence System |
| |
| This module implements state-of-the-art neural machine translation using Helsinki-NLP/Opus-MT |
| models. Designed for efficient CPU-based translation with dynamic model loading and |
| intelligent batching strategies. |
| |
| Key Features: |
| - Dynamic model loading for 100+ language pairs |
| - Helsinki-NLP/Opus-MT models (300MB each) for specific language pairs |
| - Intelligent batching for maximum CPU throughput |
| - Fallback to multilingual models (mBART, M2M-100) for rare languages |
| - Memory-efficient model management with automatic cleanup |
| - Robust error handling and translation confidence scoring |
| - Cache management for frequently used language pairs |
| |
| Models: Helsinki-NLP/opus-mt-* series, Facebook mBART50, M2M-100 |
| Dependencies: transformers, torch, sentencepiece |
| """ |
|
|
| import os |
| import logging |
| import warnings |
| import torch |
| from typing import List, Dict, Optional, Tuple, Union |
| import gc |
| from dataclasses import dataclass |
| from collections import defaultdict |
| import time |
|
|
| try: |
| from transformers import ( |
| MarianMTModel, MarianTokenizer, |
| MBartForConditionalGeneration, MBart50TokenizerFast, |
| M2M100ForConditionalGeneration, M2M100Tokenizer, |
| pipeline |
| ) |
| TRANSFORMERS_AVAILABLE = True |
| except ImportError: |
| TRANSFORMERS_AVAILABLE = False |
| logging.warning("transformers not available. Install with: pip install transformers") |
|
|
| |
| logging.basicConfig(level=logging.INFO) |
| logger = logging.getLogger(__name__) |
|
|
| |
| warnings.filterwarnings("ignore", category=UserWarning) |
| warnings.filterwarnings("ignore", category=FutureWarning) |
|
|
|
|
| @dataclass |
| class TranslationResult: |
| """ |
| Data class representing a translation result with metadata. |
| |
| Attributes: |
| original_text (str): Original text in source language |
| translated_text (str): Translated text in target language |
| source_language (str): Source language code |
| target_language (str): Target language code |
| confidence (float): Translation confidence score |
| model_used (str): Name of the model used for translation |
| processing_time (float): Time taken for translation in seconds |
| """ |
| original_text: str |
| translated_text: str |
| source_language: str |
| target_language: str |
| confidence: float = 1.0 |
| model_used: str = "unknown" |
| processing_time: float = 0.0 |
| |
| def to_dict(self) -> dict: |
| """Convert to dictionary for JSON serialization.""" |
| return { |
| 'original_text': self.original_text, |
| 'translated_text': self.translated_text, |
| 'source_language': self.source_language, |
| 'target_language': self.target_language, |
| 'confidence': self.confidence, |
| 'model_used': self.model_used, |
| 'processing_time': self.processing_time |
| } |
|
|
|
|
| class NeuralTranslator: |
| """ |
| Advanced neural machine translation with dynamic model loading. |
| |
| Supports 100+ languages through Helsinki-NLP/Opus-MT models with intelligent |
| fallback strategies and efficient memory management. |
| """ |
| |
| def __init__(self, |
| target_language: str = "en", |
| device: Optional[str] = None, |
| cache_size: int = 3, |
| use_multilingual_fallback: bool = True, |
| model_cache_dir: Optional[str] = None): |
| """ |
| Initialize the Neural Translator. |
| |
| Args: |
| target_language (str): Target language code (default: 'en' for English) |
| device (str, optional): Device to run on ('cpu', 'cuda', 'auto') |
| cache_size (int): Maximum number of models to keep in memory |
| use_multilingual_fallback (bool): Use mBART/M2M-100 for unsupported pairs |
| model_cache_dir (str, optional): Directory to cache downloaded models |
| """ |
| self.target_language = target_language |
| self.cache_size = cache_size |
| self.use_multilingual_fallback = use_multilingual_fallback |
| self.model_cache_dir = model_cache_dir |
| |
| |
| if device == 'auto' or device is None: |
| self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| else: |
| self.device = torch.device(device) |
| |
| logger.info(f"Initializing NeuralTranslator: target={target_language}, " |
| f"device={self.device}, cache_size={cache_size}") |
| |
| |
| self.model_cache = {} |
| self.fallback_model = None |
| self.fallback_tokenizer = None |
| self.fallback_model_name = None |
| |
| |
| self.language_mapping = self._get_language_mapping() |
| |
| |
| self._supported_pairs_cache = None |
| |
| |
| if use_multilingual_fallback: |
| self._load_fallback_model() |
| |
| def _get_language_mapping(self) -> Dict[str, str]: |
| """Get mapping of language codes to Helsinki-NLP model codes.""" |
| |
| return { |
| 'en': 'en', 'es': 'es', 'fr': 'fr', 'de': 'de', 'it': 'it', 'pt': 'pt', |
| 'ru': 'ru', 'zh': 'zh', 'ja': 'ja', 'ko': 'ko', 'ar': 'ar', 'hi': 'hi', |
| 'tr': 'tr', 'pl': 'pl', 'nl': 'nl', 'sv': 'sv', 'da': 'da', 'no': 'no', |
| 'fi': 'fi', 'hu': 'hu', 'cs': 'cs', 'sk': 'sk', 'sl': 'sl', 'hr': 'hr', |
| 'bg': 'bg', 'ro': 'ro', 'el': 'el', 'he': 'he', 'th': 'th', 'vi': 'vi', |
| 'id': 'id', 'ms': 'ms', 'tl': 'tl', 'sw': 'sw', 'eu': 'eu', 'ca': 'ca', |
| 'gl': 'gl', 'cy': 'cy', 'ga': 'ga', 'mt': 'mt', 'is': 'is', 'lv': 'lv', |
| 'lt': 'lt', 'et': 'et', 'mk': 'mk', 'sq': 'sq', 'be': 'be', 'uk': 'uk', |
| 'ka': 'ka', 'hy': 'hy', 'az': 'az', 'kk': 'kk', 'ky': 'ky', 'uz': 'uz', |
| 'fa': 'fa', 'ur': 'ur', 'bn': 'bn', 'ta': 'ta', 'te': 'te', 'ml': 'ml', |
| 'kn': 'kn', 'gu': 'gu', 'pa': 'pa', 'mr': 'mr', 'ne': 'ne', 'si': 'si', |
| 'my': 'my', 'km': 'km', 'lo': 'lo', 'mn': 'mn', 'bo': 'bo' |
| } |
| |
| def _load_fallback_model(self): |
| """Load multilingual fallback model (mBART50 or M2M-100).""" |
| try: |
| |
| logger.info("Loading mBART50 multilingual fallback model...") |
| |
| self.fallback_model = MBartForConditionalGeneration.from_pretrained( |
| "facebook/mbart-large-50-many-to-many-mmt", |
| cache_dir=self.model_cache_dir |
| ).to(self.device) |
| |
| self.fallback_tokenizer = MBart50TokenizerFast.from_pretrained( |
| "facebook/mbart-large-50-many-to-many-mmt", |
| cache_dir=self.model_cache_dir |
| ) |
| |
| self.fallback_model_name = "mbart50" |
| logger.info("mBART50 fallback model loaded successfully") |
| |
| except Exception as e: |
| logger.warning(f"Failed to load mBART50: {e}") |
| |
| try: |
| |
| logger.info("Loading M2M-100 multilingual fallback model...") |
| |
| self.fallback_model = M2M100ForConditionalGeneration.from_pretrained( |
| "facebook/m2m100_418M", |
| cache_dir=self.model_cache_dir |
| ).to(self.device) |
| |
| self.fallback_tokenizer = M2M100Tokenizer.from_pretrained( |
| "facebook/m2m100_418M", |
| cache_dir=self.model_cache_dir |
| ) |
| |
| self.fallback_model_name = "m2m100" |
| logger.info("M2M-100 fallback model loaded successfully") |
| |
| except Exception as e2: |
| logger.warning(f"Failed to load M2M-100: {e2}") |
| self.fallback_model = None |
| self.fallback_tokenizer = None |
| self.fallback_model_name = None |
| |
| def translate_text(self, |
| text: str, |
| source_language: str, |
| target_language: Optional[str] = None) -> TranslationResult: |
| """ |
| Translate a single text segment. |
| |
| Args: |
| text (str): Text to translate |
| source_language (str): Source language code |
| target_language (str, optional): Target language code (uses default if None) |
| |
| Returns: |
| TranslationResult: Translation result with metadata |
| """ |
| if not text or not text.strip(): |
| return TranslationResult( |
| original_text=text, |
| translated_text=text, |
| source_language=source_language, |
| target_language=target_language or self.target_language, |
| confidence=0.0, |
| model_used="none", |
| processing_time=0.0 |
| ) |
| |
| target_lang = target_language or self.target_language |
| |
| |
| if source_language == target_lang: |
| return TranslationResult( |
| original_text=text, |
| translated_text=text, |
| source_language=source_language, |
| target_language=target_lang, |
| confidence=1.0, |
| model_used="identity", |
| processing_time=0.0 |
| ) |
| |
| start_time = time.time() |
| |
| try: |
| |
| model_name = self._get_model_name(source_language, target_lang) |
| |
| if model_name: |
| result = self._translate_with_opus_mt( |
| text, source_language, target_lang, model_name |
| ) |
| elif self.fallback_model: |
| result = self._translate_with_fallback( |
| text, source_language, target_lang |
| ) |
| else: |
| |
| result = TranslationResult( |
| original_text=text, |
| translated_text=text, |
| source_language=source_language, |
| target_language=target_lang, |
| confidence=0.0, |
| model_used="unavailable", |
| processing_time=0.0 |
| ) |
| |
| result.processing_time = time.time() - start_time |
| return result |
| |
| except Exception as e: |
| logger.error(f"Translation failed: {e}") |
| return TranslationResult( |
| original_text=text, |
| translated_text=text, |
| source_language=source_language, |
| target_language=target_lang, |
| confidence=0.0, |
| model_used="error", |
| processing_time=time.time() - start_time |
| ) |
| |
| def translate_batch(self, |
| texts: List[str], |
| source_languages: List[str], |
| target_language: Optional[str] = None, |
| batch_size: int = 8) -> List[TranslationResult]: |
| """ |
| Translate multiple texts efficiently using batching. |
| |
| Args: |
| texts (List[str]): List of texts to translate |
| source_languages (List[str]): List of source language codes |
| target_language (str, optional): Target language code |
| batch_size (int): Batch size for processing |
| |
| Returns: |
| List[TranslationResult]: List of translation results |
| """ |
| if len(texts) != len(source_languages): |
| raise ValueError("Number of texts must match number of source languages") |
| |
| target_lang = target_language or self.target_language |
| results = [] |
| |
| |
| language_groups = defaultdict(list) |
| for i, (text, src_lang) in enumerate(zip(texts, source_languages)): |
| if text and text.strip(): |
| language_groups[(src_lang, target_lang)].append((i, text)) |
| |
| |
| for (src_lang, tgt_lang), items in language_groups.items(): |
| if src_lang == tgt_lang: |
| |
| for idx, text in items: |
| results.append((idx, TranslationResult( |
| original_text=text, |
| translated_text=text, |
| source_language=src_lang, |
| target_language=tgt_lang, |
| confidence=1.0, |
| model_used="identity", |
| processing_time=0.0 |
| ))) |
| else: |
| |
| for i in range(0, len(items), batch_size): |
| batch_items = items[i:i + batch_size] |
| batch_texts = [item[1] for item in batch_items] |
| batch_indices = [item[0] for item in batch_items] |
| |
| batch_results = self._translate_batch_same_language( |
| batch_texts, src_lang, tgt_lang |
| ) |
| |
| for idx, result in zip(batch_indices, batch_results): |
| results.append((idx, result)) |
| |
| |
| final_results = [None] * len(texts) |
| for idx, result in results: |
| final_results[idx] = result |
| |
| |
| for i, result in enumerate(final_results): |
| if result is None: |
| final_results[i] = TranslationResult( |
| original_text=texts[i], |
| translated_text=texts[i], |
| source_language=source_languages[i], |
| target_language=target_lang, |
| confidence=0.0, |
| model_used="empty", |
| processing_time=0.0 |
| ) |
| |
| return final_results |
| |
| def _translate_batch_same_language(self, |
| texts: List[str], |
| source_language: str, |
| target_language: str) -> List[TranslationResult]: |
| """Translate a batch of texts from the same source language.""" |
| try: |
| model_name = self._get_model_name(source_language, target_language) |
| |
| if model_name: |
| return self._translate_batch_opus_mt( |
| texts, source_language, target_language, model_name |
| ) |
| elif self.fallback_model: |
| return self._translate_batch_fallback( |
| texts, source_language, target_language |
| ) |
| else: |
| |
| return [ |
| TranslationResult( |
| original_text=text, |
| translated_text=text, |
| source_language=source_language, |
| target_language=target_language, |
| confidence=0.0, |
| model_used="unavailable", |
| processing_time=0.0 |
| ) |
| for text in texts |
| ] |
| |
| except Exception as e: |
| logger.error(f"Batch translation failed: {e}") |
| return [ |
| TranslationResult( |
| original_text=text, |
| translated_text=text, |
| source_language=source_language, |
| target_language=target_language, |
| confidence=0.0, |
| model_used="error", |
| processing_time=0.0 |
| ) |
| for text in texts |
| ] |
| |
| def _get_model_name(self, source_lang: str, target_lang: str) -> Optional[str]: |
| """Get Helsinki-NLP model name for language pair.""" |
| |
| src_mapped = self.language_mapping.get(source_lang, source_lang) |
| tgt_mapped = self.language_mapping.get(target_lang, target_lang) |
| |
| |
| model_patterns = [ |
| f"Helsinki-NLP/opus-mt-{src_mapped}-{tgt_mapped}", |
| f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}", |
| f"Helsinki-NLP/opus-mt-{src_mapped}-{target_lang}", |
| f"Helsinki-NLP/opus-mt-{source_lang}-{tgt_mapped}" |
| ] |
| |
| |
| if target_lang == 'en': |
| |
| group_patterns = [ |
| f"Helsinki-NLP/opus-mt-mul-{target_lang}", |
| f"Helsinki-NLP/opus-mt-roa-{target_lang}", |
| f"Helsinki-NLP/opus-mt-gem-{target_lang}", |
| f"Helsinki-NLP/opus-mt-sla-{target_lang}", |
| ] |
| model_patterns.extend(group_patterns) |
| |
| |
| return model_patterns[0] if model_patterns else None |
| |
| def _load_opus_mt_model(self, model_name: str) -> Tuple[MarianMTModel, MarianTokenizer]: |
| """Load Helsinki-NLP Opus-MT model with caching.""" |
| current_time = time.time() |
| |
| |
| if model_name in self.model_cache: |
| model, tokenizer, _ = self.model_cache[model_name] |
| |
| self.model_cache[model_name] = (model, tokenizer, current_time) |
| logger.debug(f"Using cached model: {model_name}") |
| return model, tokenizer |
| |
| |
| if len(self.model_cache) >= self.cache_size: |
| self._clean_model_cache() |
| |
| try: |
| logger.info(f"Loading model: {model_name}") |
| |
| |
| model = MarianMTModel.from_pretrained( |
| model_name, |
| cache_dir=self.model_cache_dir |
| ).to(self.device) |
| |
| tokenizer = MarianTokenizer.from_pretrained( |
| model_name, |
| cache_dir=self.model_cache_dir |
| ) |
| |
| |
| self.model_cache[model_name] = (model, tokenizer, current_time) |
| logger.info(f"Model loaded and cached: {model_name}") |
| |
| return model, tokenizer |
| |
| except Exception as e: |
| logger.warning(f"Failed to load model {model_name}: {e}") |
| raise |
| |
| def _clean_model_cache(self): |
| """Remove least recently used model from cache.""" |
| if not self.model_cache: |
| return |
| |
| |
| lru_model = min(self.model_cache.items(), key=lambda x: x[1][2]) |
| model_name = lru_model[0] |
| |
| |
| model, tokenizer, _ = self.model_cache.pop(model_name) |
| del model, tokenizer |
| |
| |
| if self.device.type == 'cuda': |
| torch.cuda.empty_cache() |
| gc.collect() |
| |
| logger.debug(f"Removed model from cache: {model_name}") |
| |
| def _translate_with_opus_mt(self, |
| text: str, |
| source_language: str, |
| target_language: str, |
| model_name: str) -> TranslationResult: |
| """Translate text using Helsinki-NLP Opus-MT model.""" |
| try: |
| model, tokenizer = self._load_opus_mt_model(model_name) |
| |
| |
| inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True) |
| inputs = {k: v.to(self.device) for k, v in inputs.items()} |
| |
| with torch.no_grad(): |
| outputs = model.generate( |
| **inputs, |
| max_length=512, |
| num_beams=4, |
| early_stopping=True, |
| do_sample=False |
| ) |
| |
| translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) |
| |
| return TranslationResult( |
| original_text=text, |
| translated_text=translated_text, |
| source_language=source_language, |
| target_language=target_language, |
| confidence=0.9, |
| model_used=model_name |
| ) |
| |
| except Exception as e: |
| logger.error(f"Opus-MT translation failed: {e}") |
| raise |
| |
| def _translate_batch_opus_mt(self, |
| texts: List[str], |
| source_language: str, |
| target_language: str, |
| model_name: str) -> List[TranslationResult]: |
| """Translate batch using Helsinki-NLP Opus-MT model.""" |
| try: |
| model, tokenizer = self._load_opus_mt_model(model_name) |
| |
| |
| inputs = tokenizer( |
| texts, |
| return_tensors="pt", |
| padding=True, |
| truncation=True, |
| max_length=512 |
| ) |
| inputs = {k: v.to(self.device) for k, v in inputs.items()} |
| |
| with torch.no_grad(): |
| outputs = model.generate( |
| **inputs, |
| max_length=512, |
| num_beams=4, |
| early_stopping=True, |
| do_sample=False |
| ) |
| |
| |
| translated_texts = [ |
| tokenizer.decode(output, skip_special_tokens=True) |
| for output in outputs |
| ] |
| |
| |
| results = [] |
| for original, translated in zip(texts, translated_texts): |
| results.append(TranslationResult( |
| original_text=original, |
| translated_text=translated, |
| source_language=source_language, |
| target_language=target_language, |
| confidence=0.9, |
| model_used=model_name |
| )) |
| |
| return results |
| |
| except Exception as e: |
| logger.error(f"Opus-MT batch translation failed: {e}") |
| raise |
| |
| def _translate_with_fallback(self, |
| text: str, |
| source_language: str, |
| target_language: str) -> TranslationResult: |
| """Translate using multilingual fallback model.""" |
| try: |
| if self.fallback_model_name == "mbart50": |
| return self._translate_with_mbart50(text, source_language, target_language) |
| elif self.fallback_model_name == "m2m100": |
| return self._translate_with_m2m100(text, source_language, target_language) |
| else: |
| raise ValueError("No fallback model available") |
| |
| except Exception as e: |
| logger.error(f"Fallback translation failed: {e}") |
| raise |
| |
| def _translate_batch_fallback(self, |
| texts: List[str], |
| source_language: str, |
| target_language: str) -> List[TranslationResult]: |
| """Translate batch using multilingual fallback model.""" |
| try: |
| if self.fallback_model_name == "mbart50": |
| return self._translate_batch_mbart50(texts, source_language, target_language) |
| elif self.fallback_model_name == "m2m100": |
| return self._translate_batch_m2m100(texts, source_language, target_language) |
| else: |
| raise ValueError("No fallback model available") |
| |
| except Exception as e: |
| logger.error(f"Fallback batch translation failed: {e}") |
| raise |
| |
| def _translate_with_mbart50(self, |
| text: str, |
| source_language: str, |
| target_language: str) -> TranslationResult: |
| """Translate using mBART50 model.""" |
| |
| self.fallback_tokenizer.src_lang = source_language |
| |
| inputs = self.fallback_tokenizer(text, return_tensors="pt") |
| inputs = {k: v.to(self.device) for k, v in inputs.items()} |
| |
| |
| with torch.no_grad(): |
| generated_tokens = self.fallback_model.generate( |
| **inputs, |
| forced_bos_token_id=self.fallback_tokenizer.lang_code_to_id[target_language], |
| max_length=512, |
| num_beams=4, |
| early_stopping=True |
| ) |
| |
| translated_text = self.fallback_tokenizer.batch_decode( |
| generated_tokens, skip_special_tokens=True |
| )[0] |
| |
| return TranslationResult( |
| original_text=text, |
| translated_text=translated_text, |
| source_language=source_language, |
| target_language=target_language, |
| confidence=0.85, |
| model_used="mbart50" |
| ) |
| |
| def _translate_batch_mbart50(self, |
| texts: List[str], |
| source_language: str, |
| target_language: str) -> List[TranslationResult]: |
| """Translate batch using mBART50 model.""" |
| |
| self.fallback_tokenizer.src_lang = source_language |
| |
| inputs = self.fallback_tokenizer( |
| texts, return_tensors="pt", padding=True, truncation=True |
| ) |
| inputs = {k: v.to(self.device) for k, v in inputs.items()} |
| |
| |
| with torch.no_grad(): |
| generated_tokens = self.fallback_model.generate( |
| **inputs, |
| forced_bos_token_id=self.fallback_tokenizer.lang_code_to_id[target_language], |
| max_length=512, |
| num_beams=4, |
| early_stopping=True |
| ) |
| |
| translated_texts = self.fallback_tokenizer.batch_decode( |
| generated_tokens, skip_special_tokens=True |
| ) |
| |
| return [ |
| TranslationResult( |
| original_text=original, |
| translated_text=translated, |
| source_language=source_language, |
| target_language=target_language, |
| confidence=0.85, |
| model_used="mbart50" |
| ) |
| for original, translated in zip(texts, translated_texts) |
| ] |
| |
| def _translate_with_m2m100(self, |
| text: str, |
| source_language: str, |
| target_language: str) -> TranslationResult: |
| """Translate using M2M-100 model.""" |
| self.fallback_tokenizer.src_lang = source_language |
| |
| inputs = self.fallback_tokenizer(text, return_tensors="pt") |
| inputs = {k: v.to(self.device) for k, v in inputs.items()} |
| |
| with torch.no_grad(): |
| generated_tokens = self.fallback_model.generate( |
| **inputs, |
| forced_bos_token_id=self.fallback_tokenizer.get_lang_id(target_language), |
| max_length=512, |
| num_beams=4, |
| early_stopping=True |
| ) |
| |
| translated_text = self.fallback_tokenizer.batch_decode( |
| generated_tokens, skip_special_tokens=True |
| )[0] |
| |
| return TranslationResult( |
| original_text=text, |
| translated_text=translated_text, |
| source_language=source_language, |
| target_language=target_language, |
| confidence=0.87, |
| model_used="m2m100" |
| ) |
| |
| def _translate_batch_m2m100(self, |
| texts: List[str], |
| source_language: str, |
| target_language: str) -> List[TranslationResult]: |
| """Translate batch using M2M-100 model.""" |
| self.fallback_tokenizer.src_lang = source_language |
| |
| inputs = self.fallback_tokenizer( |
| texts, return_tensors="pt", padding=True, truncation=True |
| ) |
| inputs = {k: v.to(self.device) for k, v in inputs.items()} |
| |
| with torch.no_grad(): |
| generated_tokens = self.fallback_model.generate( |
| **inputs, |
| forced_bos_token_id=self.fallback_tokenizer.get_lang_id(target_language), |
| max_length=512, |
| num_beams=4, |
| early_stopping=True |
| ) |
| |
| translated_texts = self.fallback_tokenizer.batch_decode( |
| generated_tokens, skip_special_tokens=True |
| ) |
| |
| return [ |
| TranslationResult( |
| original_text=original, |
| translated_text=translated, |
| source_language=source_language, |
| target_language=target_language, |
| confidence=0.87, |
| model_used="m2m100" |
| ) |
| for original, translated in zip(texts, translated_texts) |
| ] |
| |
| def get_supported_languages(self) -> List[str]: |
| """Get list of supported source languages.""" |
| |
| opus_mt_languages = list(self.language_mapping.keys()) |
| |
| |
| mbart_languages = [ |
| 'ar', 'cs', 'de', 'en', 'es', 'et', 'fi', 'fr', 'gu', 'hi', 'it', 'ja', |
| 'kk', 'ko', 'lt', 'lv', 'my', 'ne', 'nl', 'ro', 'ru', 'si', 'tr', 'vi', |
| 'zh', 'af', 'az', 'bn', 'fa', 'he', 'hr', 'id', 'ka', 'km', 'mk', 'ml', |
| 'mn', 'mr', 'pl', 'ps', 'pt', 'sv', 'sw', 'ta', 'te', 'th', 'tl', 'uk', |
| 'ur', 'xh', 'gl', 'sl' |
| ] |
| |
| |
| m2m_additional = [ |
| 'am', 'cy', 'is', 'mg', 'mt', 'so', 'zu', 'ha', 'ig', 'yo', 'lg', 'ln', |
| 'rn', 'sn', 'tn', 'ts', 've', 'xh', 'zu' |
| ] |
| |
| all_languages = set(opus_mt_languages + mbart_languages + m2m_additional) |
| return sorted(list(all_languages)) |
| |
| def clear_cache(self): |
| """Clear all cached models to free memory.""" |
| logger.info("Clearing model cache...") |
| |
| for model_name, (model, tokenizer, _) in self.model_cache.items(): |
| del model, tokenizer |
| |
| self.model_cache.clear() |
| |
| if self.device.type == 'cuda': |
| torch.cuda.empty_cache() |
| gc.collect() |
| |
| logger.info("Model cache cleared") |
| |
| def get_cache_info(self) -> Dict[str, any]: |
| """Get information about cached models.""" |
| return { |
| 'cached_models': list(self.model_cache.keys()), |
| 'cache_size': len(self.model_cache), |
| 'max_cache_size': self.cache_size, |
| 'fallback_model': self.fallback_model_name, |
| 'device': str(self.device) |
| } |
| |
| def __del__(self): |
| """Cleanup resources when the object is destroyed.""" |
| try: |
| self.clear_cache() |
| except Exception: |
| pass |
|
|
|
|
| |
| def translate_text(text: str, |
| source_language: str, |
| target_language: str = "en", |
| device: Optional[str] = None) -> TranslationResult: |
| """ |
| Convenience function to translate text with default settings. |
| |
| Args: |
| text (str): Text to translate |
| source_language (str): Source language code |
| target_language (str): Target language code (default: 'en') |
| device (str, optional): Device to run on ('cpu', 'cuda', 'auto') |
| |
| Returns: |
| TranslationResult: Translation result |
| |
| Example: |
| >>> # Translate from French to English |
| >>> result = translate_text("Bonjour le monde", "fr", "en") |
| >>> print(result.translated_text) # "Hello world" |
| >>> |
| >>> # Translate from Hindi to English |
| >>> result = translate_text("नमस्ते", "hi", "en") |
| >>> print(result.translated_text) # "Hello" |
| """ |
| translator = NeuralTranslator( |
| target_language=target_language, |
| device=device |
| ) |
| |
| return translator.translate_text(text, source_language, target_language) |
|
|
|
|
| |
| if __name__ == "__main__": |
| import sys |
| import argparse |
| import json |
| |
| def main(): |
| """Command line interface for testing neural translation.""" |
| parser = argparse.ArgumentParser(description="Neural Machine Translation Tool") |
| parser.add_argument("text", help="Text to translate") |
| parser.add_argument("--source-lang", "-s", required=True, |
| help="Source language code") |
| parser.add_argument("--target-lang", "-t", default="en", |
| help="Target language code (default: en)") |
| parser.add_argument("--device", choices=["cpu", "cuda", "auto"], default="auto", |
| help="Device to run on") |
| parser.add_argument("--batch-size", type=int, default=8, |
| help="Batch size for multiple texts") |
| parser.add_argument("--output-format", choices=["json", "text"], |
| default="text", help="Output format") |
| parser.add_argument("--list-languages", action="store_true", |
| help="List supported languages") |
| parser.add_argument("--benchmark", action="store_true", |
| help="Run translation benchmark") |
| parser.add_argument("--verbose", "-v", action="store_true", |
| help="Enable verbose logging") |
| |
| args = parser.parse_args() |
| |
| if args.verbose: |
| logging.getLogger().setLevel(logging.DEBUG) |
| |
| try: |
| translator = NeuralTranslator( |
| target_language=args.target_lang, |
| device=args.device |
| ) |
| |
| if args.list_languages: |
| languages = translator.get_supported_languages() |
| print("Supported languages:") |
| for i, lang in enumerate(languages): |
| print(f"{lang:>4}", end=" ") |
| if (i + 1) % 10 == 0: |
| print() |
| if len(languages) % 10 != 0: |
| print() |
| return |
| |
| if args.benchmark: |
| print("=== TRANSLATION BENCHMARK ===") |
| test_texts = [ |
| "Hello, how are you?", |
| "This is a longer sentence to test translation quality.", |
| "Machine translation has improved significantly." |
| ] |
| |
| start_time = time.time() |
| results = translator.translate_batch( |
| test_texts, |
| [args.source_lang] * len(test_texts), |
| args.target_lang |
| ) |
| total_time = time.time() - start_time |
| |
| print(f"Translated {len(test_texts)} texts in {total_time:.2f}s") |
| print(f"Average time per text: {total_time/len(test_texts):.3f}s") |
| print() |
| |
| |
| result = translator.translate_text( |
| args.text, args.source_lang, args.target_lang |
| ) |
| |
| |
| if args.output_format == "json": |
| print(json.dumps(result.to_dict(), indent=2, ensure_ascii=False)) |
| else: |
| print(f"=== TRANSLATION RESULT ===") |
| print(f"Source ({result.source_language}): {result.original_text}") |
| print(f"Target ({result.target_language}): {result.translated_text}") |
| print(f"Model used: {result.model_used}") |
| print(f"Confidence: {result.confidence:.2f}") |
| print(f"Processing time: {result.processing_time:.3f}s") |
| |
| if args.verbose: |
| cache_info = translator.get_cache_info() |
| print(f"\nCache info: {cache_info}") |
| |
| except Exception as e: |
| print(f"Error: {e}", file=sys.stderr) |
| sys.exit(1) |
| |
| |
| if not TRANSFORMERS_AVAILABLE: |
| print("Warning: transformers not available. Install with: pip install transformers") |
| print("Running in demo mode...") |
| |
| |
| dummy_result = TranslationResult( |
| original_text="Bonjour le monde", |
| translated_text="Hello world", |
| source_language="fr", |
| target_language="en", |
| confidence=0.95, |
| model_used="demo", |
| processing_time=0.123 |
| ) |
| |
| print("\n=== DEMO OUTPUT (transformers not available) ===") |
| print(f"Source (fr): {dummy_result.original_text}") |
| print(f"Target (en): {dummy_result.translated_text}") |
| print(f"Confidence: {dummy_result.confidence:.2f}") |
| else: |
| main() |